diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 929e1319e41b..b0d951ab852f 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -102,6 +102,8 @@ title: "Latent Diffusion" - local: api/pipelines/latent_diffusion_uncond title: "Unconditional Latent Diffusion" + - local: api/pipelines/paint_by_example + title: "PaintByExample" - local: api/pipelines/pndm title: "PNDM" - local: api/pipelines/score_sde_ve diff --git a/docs/source/api/pipelines/overview.mdx b/docs/source/api/pipelines/overview.mdx index a0a3f3d77dd9..05c8d53adca0 100644 --- a/docs/source/api/pipelines/overview.mdx +++ b/docs/source/api/pipelines/overview.mdx @@ -53,6 +53,7 @@ available a colab notebook to directly try them out. | [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation | | [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image | | [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation | +| [paint_by_example](./api/pipelines/paint_by_example) | [**Paint by Example: Exemplar-based Image Editing with Diffusion Models**](https://arxiv.org/abs/2211.13227) | Image-Guided Image Inpainting | | [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation | | [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | | [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | diff --git a/docs/source/api/pipelines/paint_by_example.mdx b/docs/source/api/pipelines/paint_by_example.mdx new file mode 100644 index 000000000000..e40b3453edf4 --- /dev/null +++ b/docs/source/api/pipelines/paint_by_example.mdx @@ -0,0 +1,73 @@ + + +# PaintByExample + +## Overview + +[Paint by Example: Exemplar-based Image Editing with Diffusion Models](https://arxiv.org/abs/2211.13227) by Binxin Yang, Shuyang Gu, Bo Zhang, Ting Zhang, Xuejin Chen, Xiaoyan Sun, Dong Chen, Fang Wen + +The abstract of the paper is the following: + +*Language-guided image editing has achieved great success recently. In this paper, for the first time, we investigate exemplar-guided image editing for more precise control. We achieve this goal by leveraging self-supervised training to disentangle and re-organize the source image and the exemplar. However, the naive approach will cause obvious fusing artifacts. We carefully analyze it and propose an information bottleneck and strong augmentations to avoid the trivial solution of directly copying and pasting the exemplar image. Meanwhile, to ensure the controllability of the editing process, we design an arbitrary shape mask for the exemplar image and leverage the classifier-free guidance to increase the similarity to the exemplar image. The whole framework involves a single forward of the diffusion model without any iterative optimization. We demonstrate that our method achieves an impressive performance and enables controllable editing on in-the-wild images with high fidelity.* + +The original codebase can be found [here](https://github.com/Fantasy-Studio/Paint-by-Example). + +## Available Pipelines: + +| Pipeline | Tasks | Colab +|---|---|:---:| +| [pipeline_paint_by_example.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py) | *Image-Guided Image Painting* | - | + +## Tips + +- PaintByExample is supported by the official [Fantasy-Studio/Paint-by-Example](https://huggingface.co/Fantasy-Studio/Paint-by-Example) checkpoint. The checkpoint has been warm-started from the [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) and with the objective to inpaint partly masked images conditioned on example / reference images +- To quickly demo *PaintByExample*, please have a look at [this demo](https://huggingface.co/spaces/Fantasy-Studio/Paint-by-Example) +- You can run the following code snippet as an example: + + +```python +# !pip install diffusers transformers + +import PIL +import requests +import torch +from io import BytesIO +from diffusers import DiffusionPipeline + + +def download_image(url): + response = requests.get(url) + return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + +img_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/image/example_1.png" +mask_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/mask/example_1.png" +example_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/reference/example_1.jpg" + +init_image = download_image(img_url).resize((512, 512)) +mask_image = download_image(mask_url).resize((512, 512)) +example_image = download_image(example_url).resize((512, 512)) + +pipe = DiffusionPipeline.from_pretrained( + "Fantasy-Studio/Paint-by-Example", + torch_dtype=torch.float16, +) +pipe = pipe.to("cuda") + +image = pipe(image=init_image, mask_image=mask_image, example_image=example_image).images[0] +image +``` + +## PaintByExamplePipeline +[[autodoc]] pipelines.paint_by_example.pipeline_paint_by_example.PaintByExamplePipeline + - __call__ diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 0cb89a459b86..f76e490893ee 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -43,6 +43,7 @@ available a colab notebook to directly try them out. | [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation | | [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image | | [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation | +| [paint_by_example](./api/pipelines/paint_by_example) | [**Paint by Example: Exemplar-based Image Editing with Diffusion Models**](https://arxiv.org/abs/2211.13227) | Image-Guided Image Inpainting | | [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation | | [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | | [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index 38900cc958b8..c3825e400a71 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -41,8 +41,9 @@ UNet2DConditionModel, ) from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel +from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker -from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer +from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig def shave_segments(path, n_shave_prefix_segments=1): @@ -647,6 +648,73 @@ def convert_ldm_clip_checkpoint(checkpoint): return text_model +def convert_paint_by_example_checkpoint(checkpoint): + config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14") + model = PaintByExampleImageEncoder(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + # load clip vision + model.model.load_state_dict(text_model_dict) + + # load mapper + keys_mapper = { + k[len("cond_stage_model.mapper.res") :]: v + for k, v in checkpoint.items() + if k.startswith("cond_stage_model.mapper") + } + + MAPPING = { + "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], + "attn.c_proj": ["attn1.to_out.0"], + "ln_1": ["norm1"], + "ln_2": ["norm3"], + "mlp.c_fc": ["ff.net.0.proj"], + "mlp.c_proj": ["ff.net.2"], + } + + mapped_weights = {} + for key, value in keys_mapper.items(): + prefix = key[: len("blocks.i")] + suffix = key.split(prefix)[-1].split(".")[-1] + name = key.split(prefix)[-1].split(suffix)[0][1:-1] + mapped_names = MAPPING[name] + + num_splits = len(mapped_names) + for i, mapped_name in enumerate(mapped_names): + new_name = ".".join([prefix, mapped_name, suffix]) + shape = value.shape[0] // num_splits + mapped_weights[new_name] = value[i * shape : (i + 1) * shape] + + model.mapper.load_state_dict(mapped_weights) + + # load final layer norm + model.final_layer_norm.load_state_dict( + { + "bias": checkpoint["cond_stage_model.final_ln.bias"], + "weight": checkpoint["cond_stage_model.final_ln.weight"], + } + ) + + # load final proj + model.proj_out.load_state_dict( + { + "bias": checkpoint["proj_out.bias"], + "weight": checkpoint["proj_out.weight"], + } + ) + + # load uncond vector + model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) + return model + + def convert_open_clip_checkpoint(checkpoint): text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") @@ -676,12 +744,24 @@ def convert_open_clip_checkpoint(checkpoint): type=str, help="The YAML config file corresponding to the original architecture.", ) + parser.add_argument( + "--num_in_channels", + default=None, + type=int, + help="The number of input channels. If `None` number of input channels will be automatically inferred.", + ) parser.add_argument( "--scheduler_type", default="pndm", type=str, help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']", ) + parser.add_argument( + "--pipeline_type", + default=None, + type=str, + help="The pipeline type. If `None` pipeline will be automatically inferred.", + ) parser.add_argument( "--image_size", default=None, @@ -737,6 +817,9 @@ def convert_open_clip_checkpoint(checkpoint): original_config = OmegaConf.load(args.original_config_file) + if args.num_in_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = args.num_in_channels + if ( "parameterization" in original_config["model"]["params"] and original_config["model"]["params"]["parameterization"] == "v" @@ -806,8 +889,11 @@ def convert_open_clip_checkpoint(checkpoint): vae.load_state_dict(converted_vae_checkpoint) # Convert the text model. - text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] - if text_model_type == "FrozenOpenCLIPEmbedder": + model_type = args.pipeline_type + if model_type is None: + model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] + + if model_type == "FrozenOpenCLIPEmbedder": text_model = convert_open_clip_checkpoint(checkpoint) tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer") pipe = StableDiffusionPipeline( @@ -820,7 +906,19 @@ def convert_open_clip_checkpoint(checkpoint): feature_extractor=None, requires_safety_checker=False, ) - elif text_model_type == "FrozenCLIPEmbedder": + elif model_type == "PaintByExample": + vision_model = convert_paint_by_example_checkpoint(checkpoint) + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") + pipe = PaintByExamplePipeline( + vae=vae, + image_encoder=vision_model, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=feature_extractor, + ) + elif model_type == "FrozenCLIPEmbedder": text_model = convert_ldm_clip_checkpoint(checkpoint) tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 73b7b8ce2faa..c4dc3e50e424 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -72,6 +72,7 @@ AltDiffusionPipeline, CycleDiffusionPipeline, LDMTextToImagePipeline, + PaintByExamplePipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 61186a253572..7ac7a263cec4 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -406,6 +406,9 @@ def __init__( ): super().__init__() self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + + # 1. Self-Attn self.attn1 = CrossAttention( query_dim=dim, heads=num_attention_heads, @@ -415,23 +418,28 @@ def __init__( cross_attention_dim=cross_attention_dim if only_cross_attention else None, ) # is a self-attention self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) - self.attn2 = CrossAttention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - ) # is self-attn if context is none - # layer norms - self.use_ada_layer_norm = num_embeds_ada_norm is not None - if self.use_ada_layer_norm: - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + # 2. Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + ) # is self-attn if context is none else: - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) + self.attn2 = None + + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + # 3. Feed-forward self.norm3 = nn.LayerNorm(dim) # if xformers is installed try to use memory_efficient_attention by default @@ -481,11 +489,12 @@ def forward(self, hidden_states, context=None, timestep=None): else: hidden_states = self.attn1(norm_hidden_states) + hidden_states - # 2. Cross-Attention - norm_hidden_states = ( - self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) - ) - hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states + if self.attn2 is not None: + # 2. Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states # 3. Feed-forward hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states @@ -666,14 +675,16 @@ def __init__( inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim - if activation_fn == "geglu": - geglu = GEGLU(dim, inner_dim) + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) elif activation_fn == "geglu-approximate": - geglu = ApproximateGELU(dim, inner_dim) + act_fn = ApproximateGELU(dim, inner_dim) self.net = nn.ModuleList([]) # project in - self.net.append(geglu) + self.net.append(act_fn) # project dropout self.net.append(nn.Dropout(dropout)) # project out @@ -685,6 +696,27 @@ def forward(self, hidden_states): return hidden_states +class GELU(nn.Module): + r""" + GELU activation function + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + # feedforward class GEGLU(nn.Module): r""" diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 9555e3346fa3..65394fd52640 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -28,6 +28,7 @@ if is_torch_available() and is_transformers_available(): from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline from .latent_diffusion import LDMTextToImagePipeline + from .paint_by_example import PaintByExamplePipeline from .stable_diffusion import ( CycleDiffusionPipeline, StableDiffusionImageVariationPipeline, diff --git a/src/diffusers/pipelines/paint_by_example/__init__.py b/src/diffusers/pipelines/paint_by_example/__init__.py new file mode 100644 index 000000000000..e234139beba1 --- /dev/null +++ b/src/diffusers/pipelines/paint_by_example/__init__.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np + +import PIL +from PIL import Image + +from ...utils import is_torch_available, is_transformers_available + + +if is_transformers_available() and is_torch_available(): + from .image_encoder import PaintByExampleImageEncoder + from .pipeline_paint_by_example import PaintByExamplePipeline diff --git a/src/diffusers/pipelines/paint_by_example/image_encoder.py b/src/diffusers/pipelines/paint_by_example/image_encoder.py new file mode 100644 index 000000000000..75b81431dbd9 --- /dev/null +++ b/src/diffusers/pipelines/paint_by_example/image_encoder.py @@ -0,0 +1,65 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import nn + +from transformers import CLIPPreTrainedModel, CLIPVisionModel + +from ...models.attention import BasicTransformerBlock +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class PaintByExampleImageEncoder(CLIPPreTrainedModel): + def __init__(self, config, proj_size=768): + super().__init__(config) + self.proj_size = proj_size + + self.model = CLIPVisionModel(config) + self.mapper = PaintByExampleMapper(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size) + self.proj_out = nn.Linear(config.hidden_size, self.proj_size) + + # uncondition for scaling + self.uncond_vector = nn.Parameter(torch.rand((1, 1, self.proj_size))) + + def forward(self, pixel_values): + clip_output = self.model(pixel_values=pixel_values) + latent_states = clip_output.pooler_output + latent_states = self.mapper(latent_states[:, None]) + latent_states = self.final_layer_norm(latent_states) + latent_states = self.proj_out(latent_states) + return latent_states + + +class PaintByExampleMapper(nn.Module): + def __init__(self, config): + super().__init__() + num_layers = (config.num_hidden_layers + 1) // 5 + hid_size = config.hidden_size + num_heads = 1 + self.blocks = nn.ModuleList( + [ + BasicTransformerBlock(hid_size, num_heads, hid_size, activation_fn="gelu", attention_bias=True) + for _ in range(num_layers) + ] + ) + + def forward(self, hidden_states): + for block in self.blocks: + hidden_states = block(hidden_states) + + return hidden_states diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py new file mode 100644 index 000000000000..55842e87a6a8 --- /dev/null +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -0,0 +1,559 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +import PIL +from diffusers.utils import is_accelerate_available +from transformers import CLIPFeatureExtractor + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import logging +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .image_encoder import PaintByExampleImageEncoder + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def prepare_mask_and_masked_image(image, mask): + """ + Prepares a pair (image, mask) to be consumed by the Paint by Example pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Batched mask + if mask.shape[0] == image.shape[0]: + mask = mask.unsqueeze(1) + else: + mask = mask.unsqueeze(0) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + assert mask.shape[1] == 1, "Mask image must have a single channel" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # paint-by-example inverses the mask + mask = 1 - mask + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + if isinstance(image, PIL.Image.Image): + image = np.array(image.convert("RGB")) + + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + if isinstance(mask, PIL.Image.Image): + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + + mask = mask[None, None] + + # paint-by-example inverses the mask + mask = 1 - mask + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * mask + + return mask, masked_image + + +class PaintByExamplePipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + image_encoder: PaintByExampleImageEncoder, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs + def check_inputs(self, image, height, width, callback_steps): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + # encode the mask image into latents space so we can concatenate it to the latents + masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) + masked_image_latents = 0.18215 * masked_image_latents + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + mask = mask.repeat(batch_size, 1, 1, 1) + masked_image_latents = masked_image_latents.repeat(batch_size, 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + uncond_embeddings = self.image_encoder.uncond_vector + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, 1, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([uncond_embeddings, image_embeddings]) + + return image_embeddings + + @torch.no_grad() + def __call__( + self, + example_image: Union[torch.FloatTensor, PIL.Image.Image], + image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + example_image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`): + The exemplar image to guide the image generation. + image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs(example_image, height, width, callback_steps) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input image + image_embeddings = self._encode_image( + example_image, device, num_images_per_prompt, do_classifier_free_guidance + ) + + # 4. Preprocess mask and image + mask, masked_image = prepare_mask_and_masked_image(image, mask_image) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + image_embeddings.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, masked_image_latents, mask], dim=1) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 11. Post-processing + image = self.decode_latents(latents) + + # 12. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + + # 13. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 2d932d240508..f5a2e55d7e02 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -64,6 +64,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class PaintByExamplePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionImageVariationPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/paint_by_example/__init__.py b/tests/pipelines/paint_by_example/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/paint_by_example/test_paint_by_example.py b/tests/pipelines/paint_by_example/test_paint_by_example.py new file mode 100644 index 000000000000..36da2374ac01 --- /dev/null +++ b/tests/pipelines/paint_by_example/test_paint_by_example.py @@ -0,0 +1,227 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import AutoencoderKL, PaintByExamplePipeline, PNDMScheduler, UNet2DConditionModel +from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder +from diffusers.utils import floats_tensor, load_image, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from PIL import Image +from transformers import CLIPVisionConfig + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = PaintByExamplePipeline + + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=9, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + scheduler = PNDMScheduler(skip_prk_steps=True) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + config = CLIPVisionConfig( + hidden_size=32, + projection_dim=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + image_size=32, + patch_size=4, + ) + image_encoder = PaintByExampleImageEncoder(config, proj_size=32) + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "image_encoder": image_encoder, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def convert_to_pt(self, image): + image = np.array(image.convert("RGB")) + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + return image + + def get_dummy_inputs(self, device="cpu", seed=0): + # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64)) + example_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((32, 32)) + example_image = self.convert_to_pt(example_image) + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "example_image": example_image, + "image": init_image, + "mask_image": mask_image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + } + return inputs + + def test_paint_by_example_inpaint(self): + components = self.get_dummy_components() + + # make sure here that pndm scheduler skips prk + pipe = PaintByExamplePipeline(**components) + pipe = pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs() + output = pipe(**inputs) + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4397, 0.5553, 0.3802, 0.5222, 0.5811, 0.4342, 0.494, 0.4577, 0.4428]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_paint_by_example_image_tensor(self): + device = "cpu" + inputs = self.get_dummy_inputs() + inputs.pop("mask_image") + image = self.convert_to_pt(inputs.pop("image")) + mask_image = image.clamp(0, 1) / 2 + + # make sure here that pndm scheduler skips prk + pipe = PaintByExamplePipeline(**self.get_dummy_components()) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + output = pipe(image=image, mask_image=mask_image[:, 0], **inputs) + out_1 = output.images + + image = image.cpu().permute(0, 2, 3, 1)[0] + mask_image = mask_image.cpu().permute(0, 2, 3, 1)[0] + + image = Image.fromarray(np.uint8(image)).convert("RGB") + mask_image = Image.fromarray(np.uint8(mask_image)).convert("RGB") + + output = pipe(**self.get_dummy_inputs()) + out_2 = output.images + + assert out_1.shape == (1, 64, 64, 3) + assert np.abs(out_1.flatten() - out_2.flatten()).max() < 5e-2 + + def test_paint_by_example_inpaint_with_num_images_per_prompt(self): + device = "cpu" + pipe = PaintByExamplePipeline(**self.get_dummy_components()) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs() + + images = pipe(**inputs, num_images_per_prompt=2).images + + # check if the output is a list of 2 images + assert len(images) == 2 + + +@slow +@require_torch_gpu +class PaintByExamplePipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_paint_by_example(self): + # make sure here that pndm scheduler skips prk + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/paint_by_example/dog_in_bucket.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/paint_by_example/mask.png" + ) + example_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/paint_by_example/panda.jpg" + ) + + pipe = PaintByExamplePipeline.from_pretrained("Fantasy-Studio/Paint-by-Example") + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=torch_device).manual_seed(321) + output = pipe( + image=init_image, + mask_image=mask_image, + example_image=example_image, + generator=generator, + guidance_scale=5.0, + num_inference_steps=50, + output_type="np", + ) + + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array( + [0.47455794, 0.47086594, 0.47683704, 0.51024145, 0.5064255, 0.5123164, 0.532502, 0.5328063, 0.5428694] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py index 911ec548b3bd..bdaa5b5c9a99 100755 --- a/tests/test_layers_utils.py +++ b/tests/test_layers_utils.py @@ -295,7 +295,7 @@ def test_spatial_transformer_default(self): output_slice = attention_scores[0, -1, -3:, -3:] expected_slice = torch.tensor( - [-1.2447, -0.0137, -0.9559, -1.5223, 0.6991, -1.0126, -2.0974, 0.8921, -1.0201], device=torch_device + [-1.9455, -0.0066, -1.3933, -1.5878, 0.5325, -0.6486, -1.8648, 0.7515, -0.9689], device=torch_device ) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) @@ -386,7 +386,7 @@ def test_spatial_transformer_dropout(self): output_slice = attention_scores[0, -1, -3:, -3:] expected_slice = torch.tensor( - [-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091], device=torch_device + [-1.9380, -0.0083, -1.3771, -1.5819, 0.5209, -0.6441, -1.8545, 0.7563, -0.9615], device=torch_device ) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) @@ -417,14 +417,13 @@ def test_spatial_transformer_discrete(self): output_slice = attention_scores[0, -2:, -3:] - expected_slice = torch.tensor([-0.8957, -1.8370, -1.3390, -0.9152, -0.5187, -1.1702], device=torch_device) + expected_slice = torch.tensor([-1.7648, -1.0241, -2.0985, -1.8035, -1.6404, -1.2098], device=torch_device) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) def test_spatial_transformer_default_norm_layers(self): spatial_transformer_block = Transformer2DModel(num_attention_heads=1, attention_head_dim=32, in_channels=32) assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == nn.LayerNorm - assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == nn.LayerNorm assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm def test_spatial_transformer_ada_norm_layers(self): @@ -436,7 +435,6 @@ def test_spatial_transformer_ada_norm_layers(self): ) assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == AdaLayerNorm - assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == AdaLayerNorm assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm def test_spatial_transformer_default_ff_layers(self): diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index f18c939e1504..6565a52c2a6c 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -302,7 +302,7 @@ def test_attention_slicing_forward_pass(self): output_with_slicing = pipe(**inputs)[0] max_diff = np.abs(output_with_slicing - output_without_slicing).max() - self.assertLess(max_diff, 1e-5, "Attention slicing should not affect the inference results") + self.assertLess(max_diff, 1e-3, "Attention slicing should not affect the inference results") @unittest.skipIf( torch_device != "cuda" or not is_accelerate_available(),