Skip to content

Commit

Permalink
Add paint by example (#1533)
Browse files Browse the repository at this point in the history
* add paint by example

* mkae loading possibel

* up

* Update src/diffusers/models/attention.py

* up

* finalize weight structure

* make example work

* make it work

* up

* up

* fix

* del

* add

* update

* Apply suggestions from code review

* correct transformer 2d

* finish

* up

* up

* up

* up

* fix

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Apply suggestions from code review

* up

* finish

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
  • Loading branch information
patrickvonplaten and pcuenca committed Dec 7, 2022
1 parent 02d83c9 commit 896c98a
Show file tree
Hide file tree
Showing 16 changed files with 1,121 additions and 34 deletions.
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@
title: "Latent Diffusion"
- local: api/pipelines/latent_diffusion_uncond
title: "Unconditional Latent Diffusion"
- local: api/pipelines/paint_by_example
title: "PaintByExample"
- local: api/pipelines/pndm
title: "PNDM"
- local: api/pipelines/score_sde_ve
Expand Down
1 change: 1 addition & 0 deletions docs/source/api/pipelines/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ available a colab notebook to directly try them out.
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [paint_by_example](./api/pipelines/paint_by_example) | [**Paint by Example: Exemplar-based Image Editing with Diffusion Models**](https://arxiv.org/abs/2211.13227) | Image-Guided Image Inpainting |
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
Expand Down
73 changes: 73 additions & 0 deletions docs/source/api/pipelines/paint_by_example.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# PaintByExample

## Overview

[Paint by Example: Exemplar-based Image Editing with Diffusion Models](https://arxiv.org/abs/2211.13227) by Binxin Yang, Shuyang Gu, Bo Zhang, Ting Zhang, Xuejin Chen, Xiaoyan Sun, Dong Chen, Fang Wen

The abstract of the paper is the following:

*Language-guided image editing has achieved great success recently. In this paper, for the first time, we investigate exemplar-guided image editing for more precise control. We achieve this goal by leveraging self-supervised training to disentangle and re-organize the source image and the exemplar. However, the naive approach will cause obvious fusing artifacts. We carefully analyze it and propose an information bottleneck and strong augmentations to avoid the trivial solution of directly copying and pasting the exemplar image. Meanwhile, to ensure the controllability of the editing process, we design an arbitrary shape mask for the exemplar image and leverage the classifier-free guidance to increase the similarity to the exemplar image. The whole framework involves a single forward of the diffusion model without any iterative optimization. We demonstrate that our method achieves an impressive performance and enables controllable editing on in-the-wild images with high fidelity.*

The original codebase can be found [here](https://github.com/Fantasy-Studio/Paint-by-Example).

## Available Pipelines:

| Pipeline | Tasks | Colab
|---|---|:---:|
| [pipeline_paint_by_example.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py) | *Image-Guided Image Painting* | - |

## Tips

- PaintByExample is supported by the official [Fantasy-Studio/Paint-by-Example](https://huggingface.co/Fantasy-Studio/Paint-by-Example) checkpoint. The checkpoint has been warm-started from the [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) and with the objective to inpaint partly masked images conditioned on example / reference images
- To quickly demo *PaintByExample*, please have a look at [this demo](https://huggingface.co/spaces/Fantasy-Studio/Paint-by-Example)
- You can run the following code snippet as an example:


```python
# !pip install diffusers transformers

import PIL
import requests
import torch
from io import BytesIO
from diffusers import DiffusionPipeline


def download_image(url):
response = requests.get(url)
return PIL.Image.open(BytesIO(response.content)).convert("RGB")


img_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/image/example_1.png"
mask_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/mask/example_1.png"
example_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/reference/example_1.jpg"

init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))
example_image = download_image(example_url).resize((512, 512))

pipe = DiffusionPipeline.from_pretrained(
"Fantasy-Studio/Paint-by-Example",
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

image = pipe(image=init_image, mask_image=mask_image, example_image=example_image).images[0]
image
```

## PaintByExamplePipeline
[[autodoc]] pipelines.paint_by_example.pipeline_paint_by_example.PaintByExamplePipeline
- __call__
1 change: 1 addition & 0 deletions docs/source/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ available a colab notebook to directly try them out.
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [paint_by_example](./api/pipelines/paint_by_example) | [**Paint by Example: Exemplar-based Image Editing with Diffusion Models**](https://arxiv.org/abs/2211.13227) | Image-Guided Image Inpainting |
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
Expand Down
106 changes: 102 additions & 4 deletions scripts/convert_original_stable_diffusion_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@
UNet2DConditionModel,
)
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig


def shave_segments(path, n_shave_prefix_segments=1):
Expand Down Expand Up @@ -647,6 +648,73 @@ def convert_ldm_clip_checkpoint(checkpoint):
return text_model


def convert_paint_by_example_checkpoint(checkpoint):
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
model = PaintByExampleImageEncoder(config)

keys = list(checkpoint.keys())

text_model_dict = {}

for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

# load clip vision
model.model.load_state_dict(text_model_dict)

# load mapper
keys_mapper = {
k[len("cond_stage_model.mapper.res") :]: v
for k, v in checkpoint.items()
if k.startswith("cond_stage_model.mapper")
}

MAPPING = {
"attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
"attn.c_proj": ["attn1.to_out.0"],
"ln_1": ["norm1"],
"ln_2": ["norm3"],
"mlp.c_fc": ["ff.net.0.proj"],
"mlp.c_proj": ["ff.net.2"],
}

mapped_weights = {}
for key, value in keys_mapper.items():
prefix = key[: len("blocks.i")]
suffix = key.split(prefix)[-1].split(".")[-1]
name = key.split(prefix)[-1].split(suffix)[0][1:-1]
mapped_names = MAPPING[name]

num_splits = len(mapped_names)
for i, mapped_name in enumerate(mapped_names):
new_name = ".".join([prefix, mapped_name, suffix])
shape = value.shape[0] // num_splits
mapped_weights[new_name] = value[i * shape : (i + 1) * shape]

model.mapper.load_state_dict(mapped_weights)

# load final layer norm
model.final_layer_norm.load_state_dict(
{
"bias": checkpoint["cond_stage_model.final_ln.bias"],
"weight": checkpoint["cond_stage_model.final_ln.weight"],
}
)

# load final proj
model.proj_out.load_state_dict(
{
"bias": checkpoint["proj_out.bias"],
"weight": checkpoint["proj_out.weight"],
}
)

# load uncond vector
model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
return model


def convert_open_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")

Expand Down Expand Up @@ -676,12 +744,24 @@ def convert_open_clip_checkpoint(checkpoint):
type=str,
help="The YAML config file corresponding to the original architecture.",
)
parser.add_argument(
"--num_in_channels",
default=None,
type=int,
help="The number of input channels. If `None` number of input channels will be automatically inferred.",
)
parser.add_argument(
"--scheduler_type",
default="pndm",
type=str,
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']",
)
parser.add_argument(
"--pipeline_type",
default=None,
type=str,
help="The pipeline type. If `None` pipeline will be automatically inferred.",
)
parser.add_argument(
"--image_size",
default=None,
Expand Down Expand Up @@ -737,6 +817,9 @@ def convert_open_clip_checkpoint(checkpoint):

original_config = OmegaConf.load(args.original_config_file)

if args.num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = args.num_in_channels

if (
"parameterization" in original_config["model"]["params"]
and original_config["model"]["params"]["parameterization"] == "v"
Expand Down Expand Up @@ -806,8 +889,11 @@ def convert_open_clip_checkpoint(checkpoint):
vae.load_state_dict(converted_vae_checkpoint)

# Convert the text model.
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
if text_model_type == "FrozenOpenCLIPEmbedder":
model_type = args.pipeline_type
if model_type is None:
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]

if model_type == "FrozenOpenCLIPEmbedder":
text_model = convert_open_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
pipe = StableDiffusionPipeline(
Expand All @@ -820,7 +906,19 @@ def convert_open_clip_checkpoint(checkpoint):
feature_extractor=None,
requires_safety_checker=False,
)
elif text_model_type == "FrozenCLIPEmbedder":
elif model_type == "PaintByExample":
vision_model = convert_paint_by_example_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
pipe = PaintByExamplePipeline(
vae=vae,
image_encoder=vision_model,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=feature_extractor,
)
elif model_type == "FrozenCLIPEmbedder":
text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
AltDiffusionPipeline,
CycleDiffusionPipeline,
LDMTextToImagePipeline,
PaintByExamplePipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
Expand Down
80 changes: 56 additions & 24 deletions src/diffusers/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,9 @@ def __init__(
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm = num_embeds_ada_norm is not None

# 1. Self-Attn
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
Expand All @@ -415,23 +418,28 @@ def __init__(
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.attn2 = CrossAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
) # is self-attn if context is none

# layer norms
self.use_ada_layer_norm = num_embeds_ada_norm is not None
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
# 2. Cross-Attn
if cross_attention_dim is not None:
self.attn2 = CrossAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
) # is self-attn if context is none
else:
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.attn2 = None

self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

if cross_attention_dim is not None:
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
else:
self.norm2 = None

# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim)

# if xformers is installed try to use memory_efficient_attention by default
Expand Down Expand Up @@ -481,11 +489,12 @@ def forward(self, hidden_states, context=None, timestep=None):
else:
hidden_states = self.attn1(norm_hidden_states) + hidden_states

# 2. Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
if self.attn2 is not None:
# 2. Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states

# 3. Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
Expand Down Expand Up @@ -666,14 +675,16 @@ def __init__(
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim

if activation_fn == "geglu":
geglu = GEGLU(dim, inner_dim)
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim)
elif activation_fn == "geglu":
act_fn = GEGLU(dim, inner_dim)
elif activation_fn == "geglu-approximate":
geglu = ApproximateGELU(dim, inner_dim)
act_fn = ApproximateGELU(dim, inner_dim)

self.net = nn.ModuleList([])
# project in
self.net.append(geglu)
self.net.append(act_fn)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
Expand All @@ -685,6 +696,27 @@ def forward(self, hidden_states):
return hidden_states


class GELU(nn.Module):
r"""
GELU activation function
"""

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)

def gelu(self, gate):
if gate.device.type != "mps":
return F.gelu(gate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = self.gelu(hidden_states)
return hidden_states


# feedforward
class GEGLU(nn.Module):
r"""
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
if is_torch_available() and is_transformers_available():
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .latent_diffusion import LDMTextToImagePipeline
from .paint_by_example import PaintByExamplePipeline
from .stable_diffusion import (
CycleDiffusionPipeline,
StableDiffusionImageVariationPipeline,
Expand Down

0 comments on commit 896c98a

Please sign in to comment.