v0.25.0: aMUSEd, faster SDXL, interruptible pipelines

@sayakpaul released this 27 Dec 13:49 · 614 commits to main since this release

aMUSEd

[Image: collage of aMUSEd sample generations]

aMUSEd is a lightweight text-to-image model based on the MUSE architecture. aMUSEd is particularly useful in applications that require a lightweight and fast model, such as generating many images quickly at once. aMUSEd is currently a research release.

aMUSEd is a VQVAE token-based transformer that can generate an image in fewer forward passes than many diffusion models. In contrast with MUSE, it uses the smaller text encoder CLIP-L/14 instead of T5-XXL. Thanks to its small parameter count and a generation process that needs only a few forward passes, aMUSEd can generate many images quickly. This benefit is particularly noticeable at larger batch sizes.

Text-to-image generation

import torch
from diffusers import AmusedPipeline

pipe = AmusedPipeline.from_pretrained(
    "amused/amused-512", variant="fp16", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

prompt = "cowboy"
image = pipe(prompt, generator=torch.manual_seed(8)).images[0]
image.save("text2image_512.png")

Image-to-image generation

import torch
from diffusers import AmusedImg2ImgPipeline
from diffusers.utils import load_image

pipe = AmusedImg2ImgPipeline.from_pretrained(
    "amused/amused-512", variant="fp16", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

prompt = "apple watercolor"
input_image = (
    load_image(
        "https://huggingface.co/amused/amused-512/resolve/main/assets/image2image_256_orig.png"
    )
    .resize((512, 512))
    .convert("RGB")
)

image = pipe(prompt, input_image, strength=0.7, generator=torch.manual_seed(3)).images[0]
image.save("image2image_512.png")

Inpainting

import torch
from diffusers import AmusedInpaintPipeline
from diffusers.utils import load_image

pipe = AmusedInpaintPipeline.from_pretrained(
    "amused/amused-512", variant="fp16", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

prompt = "a man with glasses"
input_image = (
    load_image(
        "https://huggingface.co/amused/amused-512/resolve/main/assets/inpainting_256_orig.png"
    )
    .resize((512, 512))
    .convert("RGB")
)
mask = (
    load_image(
        "https://huggingface.co/amused/amused-512/resolve/main/assets/inpainting_256_mask.png"
    )
    .resize((512, 512))
    .convert("L")
)

image = pipe(prompt, input_image, mask, generator=torch.manual_seed(3)).images[0]
image.save("inpainting_512.png")

📜 Docs: https://huggingface.co/docs/diffusers/main/en/api/pipelines/amused

🛠️ Models:

Faster SDXL

We're excited to present an array of optimization techniques that can be used to reduce the inference latency of text-to-image diffusion models. All of these can be done in native PyTorch without requiring additional C++ code.

[Chart: SDXL inference latency across the optimization techniques, batch size 1, 30 steps]

These techniques are not specific to Stable Diffusion XL (SDXL) and can be used to improve other text-to-image diffusion models too. Starting from default fp32 precision, we can achieve a 3x speed improvement by applying different PyTorch optimization techniques. We encourage you to check out the detailed docs provided below.

Note: Compared to the default way most people use Diffusers, which is fp16 + SDPA, applying all the optimizations explained in the blog post below yields a 30% speed-up.
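
As a rough illustration, here is a minimal sketch of the half-precision + torch.compile portion of the recipe (the model ID and compile flags are illustrative; on PyTorch 2.x, SDPA is already the default attention backend in Diffusers, so it needs no extra call):

import torch
from diffusers import StableDiffusionXLPipeline

# Load SDXL in half precision (assumes a CUDA GPU).
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")

# channels_last tends to help the convolution-heavy UNet and VAE.
pipe.unet.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

# Compile the two hottest components; the first call pays the
# compilation cost, subsequent calls run the optimized kernels.
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=30).images[0]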

📜 Docs: https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion
🌸 PyTorch blog post: https://pytorch.org/blog/accelerating-generative-ai-3/

Interruptible pipelines

Interrupting the diffusion process is particularly useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.

The callback function takes the following arguments: pipe, i, t, and callback_kwargs (which the callback must return). Set the pipeline's _interrupt attribute to True to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback.

In this example, the diffusion process is stopped after 10 steps even though num_inference_steps is set to 50.

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_model_cpu_offload()
num_inference_steps = 50

def interrupt_callback(pipe, i, t, callback_kwargs):
    stop_idx = 10
    if i == stop_idx:
        pipe._interrupt = True

    return callback_kwargs

pipe(
    "A photo of a cat",
    num_inference_steps=num_inference_steps,
    callback_on_step_end=interrupt_callback,
)

📜 Docs: https://huggingface.co/docs/diffusers/main/en/using-diffusers/callback

peft in our LoRA training examples

We incorporated peft in all the officially supported training examples concerning LoRA. This greatly simplifies the code and improves readability. LoRA training has never been easier, thanks to peft!
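
For a flavor of what the integration looks like, here is a minimal sketch of the core pattern the scripts now follow (the rank and target modules are illustrative): freeze the base model and inject a trainable LoraConfig via add_adapter.

import torch
from diffusers import UNet2DConditionModel
from peft import LoraConfig

# Freeze the base UNet; only the injected LoRA layers will be trained.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet.requires_grad_(False)

# Inject low-rank adapters into the attention projections.
unet.add_adapter(
    LoraConfig(
        r=4,
        lora_alpha=4,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
)

# Only the LoRA parameters require gradients at this point.
lora_params = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(lora_params, lr=1e-4)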

More memory-friendly version of LCM LoRA SDXL training

We incorporated best practices from peft to make LCM LoRA training for SDXL more memory-friendly. As such, you don't have to initialize two UNets (teacher and student) anymore. This version also integrates with the datasets library for quick experimentation. Check out this section for more details.
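
The idea, roughly: the frozen base UNet acts as the teacher, and the same UNet with its LoRA adapter enabled acts as the student, so toggling the adapter replaces the second copy of the weights. A minimal sketch of that pattern (not the full distillation script; hyperparameters are illustrative):

import torch
from diffusers import UNet2DConditionModel
from peft import LoraConfig

# A single UNet plays both roles: the frozen base weights are the
# teacher, and the injected LoRA adapter turns it into the student.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet.requires_grad_(False)
unet.add_adapter(
    LoraConfig(r=64, lora_alpha=64, target_modules=["to_k", "to_q", "to_v", "to_out.0"])
)

def teacher_pred(unet, *args, **kwargs):
    # Disable the adapter to recover the frozen teacher, predict, then
    # re-enable it, instead of holding a second UNet in memory.
    unet.disable_adapters()
    with torch.no_grad():
        out = unet(*args, **kwargs).sample
    unet.enable_adapters()
    return out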

All commits

Significant community contributions

The following contributors have made significant changes to the library over the last release:

  • @hako-mikan
    • [Community Pipeline] Regional Prompting Pipeline (#6015)
    • [Fix] Fix Regional Prompting Pipeline (#6188)
  • @TonyLianLong
    • LLMGroundedDiffusionPipeline: inherit from DiffusionPipeline and fix peft (#6023)
  • @okotaku
    • [Feature] Support IP-Adapter Plus (#5915)
  • @RuoyiDu
    • [Community Pipeline] DemoFusion: Democratising High-Resolution Image Generation With No $$$ (#6022)
  • @UmerHA
    • Add ControlNet-XS support (#5827)
  • @a-r-r-o-w
    • [Community] AnimateDiff + Controlnet Pipeline (#5928)
    • IP adapter support for most pipelines (#5900)
    • Add missing subclass docs, Fix broken example in SD_safe (#6116)
    • Support img2img and inpaint in lpw-xl (#6114)
  • @Monohydroxides
    • [Community] Add SDE Drag pipeline (#6105)
  • @dg845
    • Clean Up Comments in LCM(-LoRA) Distillation Scripts. (#6145)
    • Change LCM-LoRA README Script Example Learning Rates to 1e-4 (#6304)
    • Add rescale_betas_zero_snr Argument to DDPMScheduler (#6305)
    • Fix LCM distillation bug when creating the guidance scale embeddings using multiple GPUs. (#6279)
  • @markkua
    • [Community Pipeline] Add Marigold Monocular Depth Estimation (#6249)