v0.20.0: SDXL ControlNets with MultiControlNet, GLIGEN, Tiny Autoencoder, SDXL DreamBooth LoRA in free-tier Colab, and more

@sayakpaul released this on 17 Aug 08:46

SDXL ControlNets 🚀

The 🧨 diffusers team has trained two ControlNets on Stable Diffusion XL (SDXL):

(Image: a grid of sample outputs from the SDXL ControlNets.)

You can find all the SDXL ControlNet checkpoints here, including some smaller ones (5 to 7x smaller).

To learn more about how to use these ControlNets for inference, check out the respective model cards and the documentation. To train custom SDXL ControlNets, you can try out our training script.
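
For a quick start, here is a minimal inference sketch. The checkpoint name and the conditioning image are illustrative placeholders; any of the released SDXL ControlNets can be swapped in.

import torch
import numpy as np
import cv2
from PIL import Image
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
from diffusers.utils import load_image

# Load an SDXL ControlNet and plug it into the SDXL base pipeline
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

# Build a Canny edge map to use as the conditioning image
input_image = load_image(
    "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)
edges = cv2.Canny(np.array(input_image), 100, 200)
canny_image = Image.fromarray(np.stack([edges] * 3, axis=-1))

image = pipe(
    "a colorful oil painting of a woman, masterpiece",
    image=canny_image,
    num_inference_steps=30,
).images[0]
image.save("controlnet_sdxl_canny.png")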

MultiControlNet for SDXL

This release also introduces support for combining multiple ControlNets trained on SDXL and performing inference with them. Refer to the documentation to learn more.
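
As a rough sketch of how this looks in code (the checkpoint names are illustrative; in practice you prepare one conditioning map per ControlNet, here a Canny edge map and an estimated depth map):

import torch
import numpy as np
import cv2
from PIL import Image
from transformers import pipeline
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
from diffusers.utils import load_image

# Combine two SDXL ControlNets in a single pipeline
controlnets = [
    ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16),
    ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16),
]
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnets, torch_dtype=torch.float16
).to("cuda")

input_image = load_image(
    "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)

# Conditioning image 1: Canny edge map
edges = cv2.Canny(np.array(input_image), 100, 200)
canny_image = Image.fromarray(np.stack([edges] * 3, axis=-1))

# Conditioning image 2: estimated depth map
depth = pipeline("depth-estimation")(input_image)["depth"]
depth_image = Image.fromarray(np.stack([np.array(depth)] * 3, axis=-1))

# One conditioning image and one conditioning scale per ControlNet
image = pipe(
    prompt="a colorful oil painting of a woman, masterpiece",
    image=[canny_image, depth_image],
    controlnet_conditioning_scale=[0.5, 0.5],
    num_inference_steps=30,
).images[0]
image.save("multi_controlnet_sdxl.png")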

GLIGEN

The GLIGEN model was developed by researchers and engineers from the University of Wisconsin-Madison, Columbia University, and Microsoft. The StableDiffusionGLIGENPipeline can generate photorealistic images conditioned on grounding inputs. Given a caption/prompt, bounding boxes, and phrases describing the objects to ground, the pipeline inserts those objects at the regions defined by the bounding boxes; if an input image is also provided, the objects are inserted into that image, otherwise a new image described by the caption/prompt is generated. It's trained on the COCO2014D and COCO2014CD datasets, and the model uses a frozen CLIP ViT-L/14 text encoder to condition itself on grounding inputs.

(GIF from the official GLIGEN website.)

Grounded inpainting

import torch
from diffusers import StableDiffusionGLIGENPipeline
from diffusers.utils import load_image

# Insert objects described by text at the region defined by bounding boxes
pipe = StableDiffusionGLIGENPipeline.from_pretrained(
    "masterful/gligen-1-4-inpainting-text-box", variant="fp16", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

input_image = load_image(
    "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/livingroom_modern.png"
)
prompt = "a birthday cake"
boxes = [[0.2676, 0.6088, 0.4773, 0.7183]]
phrases = ["a birthday cake"]

images = pipe(
    prompt=prompt,
    gligen_phrases=phrases,
    gligen_inpaint_image=input_image,
    gligen_boxes=boxes,
    gligen_scheduled_sampling_beta=1,
    output_type="pil",
    num_inference_steps=50,
).images
images[0].save("./gligen-1-4-inpainting-text-box.jpg")

Grounded generation

import torch
from diffusers import StableDiffusionGLIGENPipeline

# Generate an image described by the prompt and
# insert objects described by text at the region defined by bounding boxes
pipe = StableDiffusionGLIGENPipeline.from_pretrained(
    "masterful/gligen-1-4-generation-text-box", variant="fp16", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

prompt = "a waterfall and a modern high speed train running through the tunnel in a beautiful forest with fall foliage"
boxes = [[0.1387, 0.2051, 0.4277, 0.7090], [0.4980, 0.4355, 0.8516, 0.7266]]
phrases = ["a waterfall", "a modern high speed train running through the tunnel"]

images = pipe(
    prompt=prompt,
    gligen_phrases=phrases,
    gligen_boxes=boxes,
    gligen_scheduled_sampling_beta=1,
    output_type="pil",
    num_inference_steps=50,
).images
images[0].save("./gligen-1-4-generation-text-box.jpg")

Refer to the documentation to learn more.

Thanks to @nikhil-masterful for contributing GLIGEN in #4441.

Tiny Autoencoder

@madebyollin trained two Autoencoders (one for Stable Diffusion and one for Stable Diffusion XL) that dramatically cut down image decoding time. The effect is especially pronounced when working with larger-resolution images. You can use AutoencoderTiny to take advantage of them.

Here鈥檚 the example usage for Stable Diffusion:

import torch
from diffusers import DiffusionPipeline, AutoencoderTiny

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "slice of delicious New York-style berry cheesecake"
image = pipe(prompt, num_inference_steps=25).images[0]
image.save("cheesecake.png")

Refer to the documentation to learn more, and to this material to understand the implications of using this Autoencoder on inference latency and memory footprint.

Fine-tuning Stable Diffusion XL with DreamBooth and LoRA on a free-tier Colab Notebook

Stable Diffusion XL's (SDXL) high memory requirements often seem restrictive when it comes to using it for downstream applications. Even with parameter-efficient fine-tuning techniques like LoRA, fine-tuning just the UNet component of SDXL can be quite memory-intensive, so running it on a free-tier Colab Notebook (which usually has a 16 GB T4 GPU attached) seems impossible.

Now, with better support for gradient checkpointing and other recipes like 8-bit Adam (via bitsandbytes), it is possible to fine-tune the UNet of SDXL with DreamBooth and LoRA on a free-tier Colab Notebook.
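
The training script exposes these recipes via flags such as --gradient_checkpointing and --use_8bit_adam. Underneath, they boil down to calls like the following; a minimal sketch of the two memory-saving ingredients, not the script itself:

import bitsandbytes as bnb
from diffusers import UNet2DConditionModel

# Load the SDXL UNet, the component being fine-tuned with DreamBooth LoRA
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)

# Recompute activations during the backward pass instead of storing them
unet.enable_gradient_checkpointing()

# Keep optimizer state in 8 bits to further reduce memory usage
optimizer = bnb.optim.AdamW8bit(unet.parameters(), lr=1e-4)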

Check out the Colab Notebook to learn more.

Thanks to @ethansmith2000 for improving the gradient checkpointing support in #4474.

Support for push_to_hub in models, schedulers, and pipelines

Our models, schedulers, and pipelines now support a push_to_hub option in save_pretrained() and also come with a dedicated push_to_hub() method. Below are some usage examples.

Models

from diffusers import ControlNetModel

controlnet = ControlNetModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    in_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    cross_attention_dim=32,
    conditioning_embedding_out_channels=(16, 32),
)
controlnet.push_to_hub("my-controlnet-model")
# or controlnet.save_pretrained("my-controlnet-model", push_to_hub=True)

Schedulers

from diffusers import DDIMScheduler

scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)
scheduler.push_to_hub("my-controlnet-scheduler")

Pipelines

from diffusers import (
    UNet2DConditionModel,
    AutoencoderKL,
    DDIMScheduler,
    StableDiffusionPipeline,
)
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTokenizer

unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

vae = AutoencoderKL(
    block_out_channels=[32, 64],
    in_channels=3,
    out_channels=3,
    down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
    up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
    latent_channels=4,
)

text_encoder_config = CLIPTextConfig(
    bos_token_id=0,
    eos_token_id=2,
    hidden_size=32,
    intermediate_size=37,
    layer_norm_eps=1e-05,
    num_attention_heads=4,
    num_hidden_layers=5,
    pad_token_id=1,
    vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

components = {
    "unet": unet,
    "scheduler": scheduler,
    "vae": vae,
    "text_encoder": text_encoder,
    "tokenizer": tokenizer,
    "safety_checker": None,
    "feature_extractor": None,
}
pipeline = StableDiffusionPipeline(**components)
pipeline.push_to_hub("my-pipeline")

Refer to the documentation to learn more.

Thanks to @Wauplin for his generous and constructive feedback (see #4218) on this feature.

Better support for loading Kohya-trained LoRA checkpoints

Seamless support for loading Kohya-trained LoRA checkpoints in diffusers is important to us. This is why we continue to improve our load_lora_weights() method. Check out the documentation to learn more about what's currently supported and the current limitations.
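
Loading such a checkpoint looks like this. The file name below is a placeholder for a Kohya-style .safetensors LoRA you have downloaded locally:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Load a Kohya-style LoRA file from the current directory (placeholder file name)
pipe.load_lora_weights(".", weight_name="kohya_style_lora.safetensors")

image = pipe("masterpiece, best quality, 1girl, light and shadow", num_inference_steps=30).images[0]
image.save("kohya_lora.png")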

Thanks to @isidentical for their help in improving this support.

Better documentation for prompt weighting

Prompt weighting provides a way to emphasize or de-emphasize certain parts of a prompt, allowing for more control over the generated image. compel offers an easy way to do prompt weighting that is compatible with diffusers. To this end, we have worked on an improved guide. Check it out here.
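
As a small taste of the guide, here is a minimal sketch with compel (the "++" syntax upweights the preceding phrase):

import torch
from compel import Compel
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Build the prompt embeddings with compel instead of passing a plain string
compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
prompt_embeds = compel_proc("a red cat playing with a ball++")

image = pipe(prompt_embeds=prompt_embeds, num_inference_steps=30).images[0]
image.save("weighted_prompt.png")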

Defaulting to serialize with .safetensors

Starting with this release, we will default to using .safetensors as our preferred serialization method. This change is reflected in all the training examples that we officially support.
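
For your own scripts, the safe_serialization argument of save_pretrained() controls the format; a minimal sketch:

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Serialize the pipeline weights as .safetensors files
pipe.save_pretrained("my-local-pipeline", safe_serialization=True)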

All commits

Significant community contributions

The following contributors have made significant changes to the library over the last release: