v0.6.0: Finetuned Stable Diffusion inpainting

@anton-l released this 19 Oct 15:52

🎨 Finetuned Stable Diffusion inpainting

The first official Stable Diffusion checkpoint fine-tuned on inpainting has been released.

You can try it out in the official demo here

or code it up yourself 💻 :

from io import BytesIO

import torch

import PIL.Image
import requests
from diffusers import StableDiffusionInpaintPipeline


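def download_image(url):
    # Fetch the image and raise on HTTP errors rather than decoding an error page.
    response = requests.get(url)
    response.raise_for_status()
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")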


img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe.to("cuda")

prompt = "Face of a yellow cat, high resolution, sitting on a park bench"

output = pipe(prompt=prompt, image=image, mask_image=mask_image)
image = output.images[0]

gives:

image | mask_image | prompt | Output
[image] | [image] | Face of a yellow cat, high resolution, sitting on a park bench | [image]

⚠️ This release deprecates the unsupervised noising-based inpainting pipeline and renames it to StableDiffusionInpaintPipelineLegacy.
The new StableDiffusionInpaintPipeline is based on a Stable Diffusion model finetuned for the inpainting task: https://huggingface.co/runwayml/stable-diffusion-inpainting

Note
When loading StableDiffusionInpaintPipeline with a non-finetuned model (i.e. one saved with diffusers<=0.5.1), the pipeline defaults to StableDiffusionInpaintPipelineLegacy to maintain backward compatibility ✨

from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

assert pipe.__class__.__name__ == "StableDiffusionInpaintPipelineLegacy"
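
To make the fallback rule concrete, here is a minimal sketch of a version-based dispatch. It is an illustration under assumptions, not the library's actual loading code: the helper names `parse_version` and `resolve_inpaint_class` are hypothetical, and the only facts taken from these notes are the two class names and the `<=0.5.1` cutoff.

```python
# Hypothetical sketch of the backward-compatibility rule described above.
# None of these helpers exist in diffusers; they only illustrate the idea.

LEGACY_CUTOFF = (0, 5, 1)  # checkpoints saved with diffusers<=0.5.1 are treated as legacy


def parse_version(version: str) -> tuple:
    """Turn '0.5.1' into (0, 5, 1), stopping at the first non-numeric piece (e.g. 'dev0')."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def resolve_inpaint_class(requested: str, saved_with: str) -> str:
    """Return the class name to instantiate for a checkpoint saved with `saved_with`."""
    is_old_checkpoint = parse_version(saved_with) <= LEGACY_CUTOFF
    if requested == "StableDiffusionInpaintPipeline" and is_old_checkpoint:
        return "StableDiffusionInpaintPipelineLegacy"
    return requested


print(resolve_inpaint_class("StableDiffusionInpaintPipeline", "0.4.0"))
print(resolve_inpaint_class("StableDiffusionInpaintPipeline", "0.6.0"))
```

Only the inpainting class is redirected in this sketch; any other requested class is returned unchanged regardless of the saved version.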

Context:

Why this change? When Stable Diffusion came out ~2 months ago, many unofficial inpainting demos appeared that used the original v1-4 checkpoint ("CompVis/stable-diffusion-v1-4"). These demos worked reasonably well, so we integrated an experimental StableDiffusionInpaintPipeline class into diffusers. Now that the official inpainting checkpoint has been released (https://github.com/runwayml/stable-diffusion), we have made the finetuned pipeline our official one and moved the old, hacky one to StableDiffusionInpaintPipelineLegacy.
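
For intuition about what "noising-based" means for the legacy approach, the sketch below shows the core idea in plain Python: at each denoising step, the region to keep is overwritten with a noised copy of the original latents, so only the masked region is actually generated. This is a simplified illustration, not the pipeline's code. The function name `blend_latents` and the flat-list representation of latents are assumptions made for readability.

```python
# Simplified illustration of mask blending in noising-based inpainting.
# Latents are flat lists of floats here; real pipelines use 4D tensors.


def blend_latents(noised_original, generated, keep_mask):
    """Keep the original where keep_mask is 1, use generated values where it is 0.

    noised_original: original image latents, noised to the current timestep
    generated:       latents produced by the current denoising step
    keep_mask:       1.0 = preserve original, 0.0 = repaint, values in between blend
    """
    return [
        orig * m + gen * (1.0 - m)
        for orig, gen, m in zip(noised_original, generated, keep_mask)
    ]


noised_original = [1.0, 1.0, 1.0]
generated = [5.0, 5.0, 5.0]
keep_mask = [1.0, 0.0, 0.5]

print(blend_latents(noised_original, generated, keep_mask))  # [1.0, 5.0, 3.0]
```

Because the base model never sees the mask, this trick works without any finetuning, which is why it is called unsupervised here. The finetuned checkpoint instead learns the inpainting task directly.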

🚀 ONNX pipelines for image2image and inpainting

Thanks to the contribution by @zledas (#552), this release supports OnnxStableDiffusionImg2ImgPipeline and OnnxStableDiffusionInpaintPipeline optimized for CPU inference:

from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionInpaintPipeline

img_pipeline = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
)

inpaint_pipeline = OnnxStableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", revision="onnx", provider="CPUExecutionProvider"
)

🌍 Community Pipelines

Two new community pipelines have been added to diffusers 🔥

Stable Diffusion Interpolation example

Interpolate the latent space of Stable Diffusion between different prompts/seeds.
For more info see stable-diffusion-videos.

For a code example, see Stable Diffusion Interpolation

  • Add Stable Diffusion Interpolation Example by @nateraw in #862
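
As a rough idea of what interpolating in latent space can look like, the sketch below implements spherical linear interpolation (slerp) between two vectors in plain Python. Treat it as an assumption-laden illustration: these notes do not say which interpolation the community pipeline uses, the `slerp` function and its `dot_threshold` parameter are written here for the example, and real latents are tensors rather than short lists. Inputs are assumed to be non-zero vectors.

```python
import math


def slerp(t, v0, v1, dot_threshold=0.9995):
    """Spherically interpolate from v0 (t=0) to v1 (t=1).

    Falls back to plain linear interpolation when the vectors are almost
    parallel, where the spherical formula becomes numerically unstable.
    """
    norm0 = math.sqrt(sum(x * x for x in v0))
    norm1 = math.sqrt(sum(x * x for x in v1))
    dot = sum(a * b for a, b in zip(v0, v1)) / (norm0 * norm1)

    if abs(dot) > dot_threshold:
        return [(1.0 - t) * a + t * b for a, b in zip(v0, v1)]

    theta = math.acos(dot)  # angle between the two vectors
    sin_theta = math.sin(theta)
    w0 = math.sin((1.0 - t) * theta) / sin_theta
    w1 = math.sin(t * theta) / sin_theta
    return [w0 * a + w1 * b for a, b in zip(v0, v1)]


# Five evenly spaced points between two toy "latents".
start, end = [1.0, 0.0], [0.0, 1.0]
for step in range(5):
    t = step / 4
    print(t, slerp(t, start, end))
```

Slerp is often preferred over linear interpolation for Gaussian noise latents because intermediate points keep a norm similar to the endpoints, whereas a straight line between two noise samples passes through lower-norm vectors.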

Stable Diffusion Mega

A single Stable Diffusion pipeline that combines the functionality of Text2Image, Image2Image and Inpainting.

For a code example, see Stable Diffusion Mega

📝 Changelog