Controllable Generation: Pix2Pix0, Attend and Excite, SEGA, SAG, ...
🎯 Controlling Generation
There has been much recent work on fine-grained control of diffusion networks!
Diffusers now supports:
- Instruct Pix2Pix
- Pix2Pix Zero (more details in the docs)
- Attend and Excite (more details in the docs)
- Semantic Guidance (more details in the docs)
- Self-Attention Guidance (more details in the docs)
- Depth2Image
- MultiDiffusion panorama (more details in the docs)
See our doc on controlling image generation and the individual pipeline docs for more details on the individual methods.
🆙 Latent Upscaler
Latent Upscaler is a diffusion model designed explicitly for Stable Diffusion. You can take the generated latent from Stable Diffusion and pass it into the upscaler before decoding with your standard VAE. Or you can take any image, encode it into the latent space, use the upscaler, and decode it. It is incredibly flexible and works with any Stable Diffusion checkpoint.
Side-by-side comparison: original output image vs. 2x upscaled output image.
The model was developed by Katherine Crowson in collaboration with Stability AI.
from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline
import torch
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
pipeline.to("cuda")
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained("stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16)
upscaler.to("cuda")
prompt = "a photo of an astronaut high resolution, unreal engine, ultra realistic"
generator = torch.manual_seed(33)
# we stay in latent space! Let's make sure that Stable Diffusion returns the image
# in latent space
low_res_latents = pipeline(prompt, generator=generator, output_type="latent").images
upscaled_image = upscaler(
    prompt=prompt,
    image=low_res_latents,
    num_inference_steps=20,
    guidance_scale=0,
    generator=generator,
).images[0]

# Let's save the upscaled image as "astronaut_1024.png"
upscaled_image.save("astronaut_1024.png")

# As a comparison, let's also save the low-res image
with torch.no_grad():
    image = pipeline.decode_latents(low_res_latents)
image = pipeline.numpy_to_pil(image)[0]
image.save("astronaut_512.png")
⚡ Optimization
In addition to new features and an increasing number of pipelines, `diffusers` cares a lot about performance. This release brings a number of optimizations that you can turn on easily.
xFormers
Memory-efficient attention, as implemented by xFormers, has been available in `diffusers` for some time. The problem was that installing xFormers could be complicated because there were no official pip wheels (or they were outdated), and you had to resort to installing from source.

From xFormers 0.0.16, official pip wheels are published with every release, so installing and using xFormers is now as simple as these two steps:

1. Run `pip install xformers` in your terminal.
2. Call `pipe.enable_xformers_memory_efficient_attention()` in your code to opt in for your pipelines.
These actions will unlock dramatic memory savings, and usually faster inference too!
See more details in the documentation.
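For reference, opting in on an existing pipeline looks roughly like this (model id and prompt are just illustrative):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Opt in to xFormers memory-efficient attention for this pipeline
pipe.enable_xformers_memory_efficient_attention()

image = pipe("a photo of an astronaut riding a horse").images[0]
```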
Torch 2.0
Speaking of memory-efficient attention, Accelerated PyTorch 2.0 Transformers now comes with built-in native support for it! When PyTorch 2.0 is released you'll no longer have to install xFormers or any third-party package to take advantage of it. In `diffusers` we are already preparing for that, and it works out of the box. So, if you happen to be using the latest "nightlies" of PyTorch 2.0 beta, then you're all set – diffusers will use Accelerated PyTorch 2.0 Transformers by default.

In our tests, the built-in PyTorch 2.0 implementation is usually as fast as xFormers', and sometimes even faster. Performance depends on the card you are using and whether you run your code in `float16` or `float32`, so check our documentation for details.
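Under the hood, this builds on the new `torch.nn.functional.scaled_dot_product_attention` primitive. As a minimal, diffusers-independent sketch of what that primitive does (toy tensor shapes, for illustration only):

```python
import torch
import torch.nn.functional as F

# Toy attention inputs shaped (batch, heads, sequence_length, head_dim)
query = torch.randn(1, 8, 77, 64)
key = torch.randn(1, 8, 77, 64)
value = torch.randn(1, 8, 77, 64)

# On PyTorch 2.0, this dispatches to a fused / memory-efficient kernel when one is available
out = F.scaled_dot_product_attention(query, key, value)
print(out.shape)  # torch.Size([1, 8, 77, 64])
```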
Coarse-grained CPU offload
Community member @keturn, with whom we have enjoyed thoughtful software design conversations, called our attention to the fact that enabling sequential CPU offloading via `enable_sequential_cpu_offload` worked great to save a lot of memory, but made inference much slower.

This is because `enable_sequential_cpu_offload()` is optimized for memory, and it recursively works across all the submodules contained in a model, moving them to the GPU when they are needed and back to the CPU when another submodule needs to run. These CPU-to-GPU-to-CPU transfers happen hundreds of times during the Stable Diffusion denoising loop, because the UNet runs multiple times and consists of several PyTorch modules.
This release of `diffusers` introduces a coarser `enable_model_cpu_offload()` pipeline API, which copies whole models (not modules) to the GPU and makes sure they stay there until another model needs to run. The consequences are:
- Less memory savings than `enable_sequential_cpu_offload`, but:
- Almost as fast inference as when the pipeline is used without any type of offloading.
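As a quick sketch (model id and prompt are illustrative; `accelerate` must be installed):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)

# Whole sub-models (text encoder, UNet, VAE) are moved to the GPU only while they run;
# no explicit pipe.to("cuda") is needed.
pipe.enable_model_cpu_offload()

image = pipe("a photo of an astronaut riding a horse").images[0]
```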
Pix2Pix Zero
Remember the CycleGAN days, when one would turn a horse into a zebra in an image while keeping the rest of the content almost untouched? Well, that day has arrived again, but in the context of diffusion. Pix2Pix Zero allows users to edit a particular image (be it real or generated), targeting a source concept (horse, for example) and replacing it with a target concept (zebra, for example).
Side-by-side comparison: input image vs. edited image.
Pix2Pix Zero was proposed in Zero-shot Image-to-Image Translation. The `StableDiffusionPix2PixZeroPipeline` allows you to:

- Edit an image generated from an input prompt
- Provide an input image and edit it

For the latter, it uses the newly introduced `DDIMInverseScheduler` to first obtain the inverted noise from the input image and use that in the subsequent generation process.
Both of the use cases leverage the idea of "edit directions", used for steering the generation toward the target concept gradually from the source concept. To know more, we recommend checking out the official documentation.
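Below is a rough sketch of the first use case (editing a generated image). The source/target embeddings here are placeholder files you would prepare yourself (or download pre-computed ones), and the exact loading and argument details should be checked against the `StableDiffusionPix2PixZeroPipeline` documentation:

```python
import torch
from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline

pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

prompt = "a high resolution painting of a cat in the style of van gogh"
# Placeholder files with pre-computed prompt embeddings for the source ("cat")
# and target ("dog") concepts; see the pipeline docs for how to generate them.
source_embeds = torch.load("cat_embeddings.pt")
target_embeds = torch.load("dog_embeddings.pt")

image = pipe(
    prompt,
    source_embeds=source_embeds,
    target_embeds=target_embeds,
    num_inference_steps=50,
).images[0]
image.save("cat_to_dog.png")
```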
Attend and excite
Attend-and-Excite was proposed in Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models. It guides the generative model to modify the cross-attention values during the image synthesis process so that the generated images depict the input text prompt more faithfully. Thanks to community contributor @evinpinar for leading the charge to add this pipeline!
- Attend and excite 2 by @evinpinar @yiyixuxu #2369
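A minimal usage sketch (prompt and token positions are illustrative; see the pipeline docs for additional knobs such as `max_iter_to_alter`):

```python
import torch
from diffusers import StableDiffusionAttendAndExcitePipeline

pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

prompt = "a cat and a frog"
# Indices of the prompt tokens that should be "excited" ("cat" and "frog")
token_indices = [2, 5]

image = pipe(prompt=prompt, token_indices=token_indices, guidance_scale=7.5).images[0]
image.save("cat_and_frog.png")
```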
Semantic guidance
Semantic Guidance for Diffusion Models was proposed in SEGA: Instructing Diffusion using Semantic Dimensions and provides strong semantic control over image generation. Small changes to the text prompt usually result in entirely different output images. However, with SEGA a variety of changes to the image can be controlled easily and intuitively while staying true to the original image composition. Thanks to the lead author of SEGA, Manuel (@manuelbrack), who added the pipeline in #2223.
Here is a simple demo:
import torch
from diffusers import SemanticStableDiffusionPipeline
pipe = SemanticStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
out = pipe(
    prompt="a photo of the face of a woman",
    num_images_per_prompt=1,
    guidance_scale=7,
    editing_prompt=[
        "smiling, smile",  # Concepts to apply
        "glasses, wearing glasses",
        "curls, wavy hair, curly hair",
        "beard, full beard, mustache",
    ],
    reverse_editing_direction=[False, False, False, False],  # Direction of guidance, i.e. increase all concepts
    edit_warmup_steps=[10, 10, 10, 10],  # Warmup period for each concept
    edit_guidance_scale=[4, 5, 5, 5.4],  # Guidance scale for each concept
    edit_threshold=[
        0.99,
        0.975,
        0.925,
        0.96,
    ],  # Threshold for each concept. The threshold equals the percentile of the latent space that will be discarded, i.e. threshold=0.99 uses 1% of the latent dimensions
    edit_momentum_scale=0.3,  # Momentum scale that will be added to the latent guidance
    edit_mom_beta=0.6,  # Momentum beta
    edit_weights=[1, 1, 1, 1],  # Weights of the individual concepts against each other (one per concept)
)
image = out.images[0]
Self-attention guidance
SAG was proposed in Improving Sample Quality of Diffusion Models Using Self-Attention Guidance. SAG works by extracting the intermediate attention map from the diffusion model at every iteration, selecting the tokens above a certain attention score, and masking and blurring them to obtain a partially blurred input. The dissimilarity between the predicted noise outputs obtained from feeding the blurred and the original input to the diffusion model is then measured and leveraged as guidance. With this guidance, the authors observe apparent improvements in a wide range of diffusion models.
import torch
from diffusers import StableDiffusionSAGPipeline
from accelerate.utils import set_seed
pipe = StableDiffusionSAGPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
seed = 8978
prompt = "."
guidance_scale = 7.5
num_images_per_prompt = 1
sag_scale = 1.0
set_seed(seed)
images = pipe(
    prompt, num_images_per_prompt=num_images_per_prompt, guidance_scale=guidance_scale, sag_scale=sag_scale
).images
images[0].save("example.png")
SAG was contributed by @SusungHong (lead author of SAG) in #2193.
MultiDiffusion panorama
MultiDiffusion was proposed in MultiDiffusion: Fusing Diffusion Paths for Controlled Image Generation. It introduces a new generation process, "MultiDiffusion", based on an optimization task that binds together multiple diffusion generation processes with a shared set of parameters or constraints.
import torch
from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler
model_ckpt = "stabilityai/stable-diffusion-2-base"
scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
pipe = StableDiffusionPanoramaPipeline.from_pretrained(model_ckpt, scheduler=scheduler, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
prompt = "a photo of the dolomites"
image = pipe(prompt).images[0]
image.save("dolomites.png")
The pipeline was contributed by @omerbt (lead author of MultiDiffusion Panorama) and @sayakpaul in #2393.
Ethical Guidelines
Diffusers is no stranger to the different opinions and perspectives about the challenges that generative technologies bring. Thanks to @giadilli, we have drafted our first Diffusers' Ethical Guidelines with which we hope to initiate a fruitful conversation with the community.
Keras Integration
Many practitioners find it easy to fine-tune the Stable Diffusion models shipped by KerasCV. At the same time, `diffusers` provides a lot of options for inference, deployment and optimization. We have made it possible to easily import and use KerasCV Stable Diffusion checkpoints in `diffusers`; read more about the process in our new guide.
🕒 UniPC scheduler
UniPC is a new fast scheduler in diffusion town! UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of a corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders.
The original codebase can be found here. Thanks to @wl-zhao for the great work and for integrating UniPC into diffusers!
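Using it with an existing pipeline is a one-line scheduler swap (model id and prompt are illustrative):

```python
import torch
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
# Swap in UniPC; it typically reaches good quality in roughly 20-25 inference steps
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=20).images[0]
```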
🏃 Training: consistent EMA support
As part of 0.13.0 we improved the support for EMA in training. We added a common `EMAModel` in `diffusers.training_utils` which can be used by all scripts. The `EMAModel` was improved to support distributed training, new methods to easily evaluate the EMA model during training, and a consistent way to save and load the EMA model similar to other models in `diffusers`.
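A rough sketch of the intended usage pattern in a training loop (model and hyperparameters are illustrative):

```python
from diffusers import UNet2DModel
from diffusers.training_utils import EMAModel

model = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)
ema_model = EMAModel(model.parameters(), decay=0.9999)

num_train_steps = 1000
for step in range(num_train_steps):
    # ... forward pass, loss, optimizer step on `model` ...
    ema_model.step(model.parameters())  # update the shadow (EMA) weights

# Evaluate with the EMA weights, then restore the training weights
ema_model.store(model.parameters())
ema_model.copy_to(model.parameters())
# ... run validation with `model` ...
ema_model.restore(model.parameters())
```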
- Fix EMA for multi-gpu training in the unconditional example by @anton-l, @patil-suraj #1930
- [Utils] Adds store() and restore() methods to EMAModel by @sayakpaul #2302
- Use accelerate save & loading hooks to have better checkpoint structure by @patrickvonplaten #2048
🐶 Ruff & black
We have replaced `flake8` with `ruff` (much faster), and updated our version of `black`. These tools are now in sync with the ones used in `transformers`, so the contributing experience is now more consistent for people using both codebases :)
All commits
- [lora] Fix bug with training without validation by @orenwang in #2106
- [Bump version] 0.13.0dev0 & Deprecate `predict_epsilon` by @patrickvonplaten in #2109
- [dreambooth] check the low-precision guard before preparing model by @patil-suraj in #2102
- [textual inversion] Allow validation images by @pcuenca in #2077
- Allow `UNet2DModel` to use arbitrary class embeddings by @pcuenca in #2080
- make scaling factor a config arg of vae/vqvae by @patil-suraj in #1860
- [Import Utils] Fix naming by @patrickvonplaten in #2118
- Fix unable to save_pretrained when using pathlib by @Cyberes in #1972
- fuse attention mask by @williamberman in #2111
- Fix model card of LoRA by @hysts in #2114
- [nit] torch_dtype used twice in doc string by @williamberman in #2126
- [LoRA] Make sure LoRA can be disabled after it's run by @patrickvonplaten in #2128
- remove redundant allow_patterns by @williamberman in #2130
- Allow lora from pipeline by @patrickvonplaten in #2129
- Fix typos in loaders.py by @kuotient in #2137
- Typo fix: `torwards` -> `towards` by @RahulBhalley in #2134
- Don't call the Hub if `local_files_only` is specified by @patrickvonplaten in #2119
- [from_pretrained] only load config one time by @williamberman in #2131
- Adding some `safetensors` docs. by @Narsil in #2122
- Fix typo by @pcuenca in #2138
- fix typo in EMAModel's load_state_dict() by @dasayan05 in #2151
- [diffusers-cli] Fix typo in accelerate and transformers versions by @pcuenca in #2154
- [Design philosophy] Create official doc by @patrickvonplaten in #2140
- Section on using LoRA alpha / scale by @pcuenca in #2139
- Don't copy when unwrapping model by @pcuenca in #2166
- Add instance prompt to model card of lora dreambooth example by @hysts in #2112
- [Bug]: fix DDPM scheduler arbitrary infer steps count. by @dudulightricks in #2076
- [examples] Fix CLI argument in the launch script command for text2image with LoRA by @sayakpaul in #2171
- [Breaking change] fix legacy inpaint noise and resize mask tensor by @1lint in #2147
- Use `requests` instead of `wget` in `convert_from_ckpt.py` by @Abhishek-Varma in #2168
- [Docs] Add components to docs by @patrickvonplaten in #2175
- [Docs] remove license by @patrickvonplaten in #2188
- Pass LoRA rank to LoRALinearLayer by @asadm in #2191
- add: guide on kerascv conversion tool. by @sayakpaul in #2169
- Fix a dimension bug in Transform2d by @lmxyy in #2144
- [Loading] Better error message on missing keys by @patrickvonplaten in #2198
- Update xFormers docs by @pcuenca in #2208
- add CITATION.cff by @kashif in #2211
- Create train_dreambooth_inpaint_lora.py by @thedarkzeno in #2205
- Docs: short section on changing the scheduler in Flax by @pcuenca in #2181
- [Bug] scheduling_ddpm: fix variance in the case of learned_range type. by @dudulightricks in #2090
- refactor onnxruntime integration by @prathikr in #2042
- Fix timestep dtype in legacy inpaint by @dymil in #2120
- [nit] negative_prompt typo by @williamberman in #2227
- removes `~`s in favor of full-fledged links. by @sayakpaul in #2229
- [LoRA] Make sure validation works in multi GPU setup by @patrickvonplaten in #2172
- fix: flagged_images implementation by @justinmerrell in #1947
- Hotfix textual inv logging by @isamu-isozaki in #2183
- Fixes LoRAXFormersCrossAttnProcessor by @jorgemcgomes in #2207
- Fix typo in StableDiffusionInpaintPipeline by @hutec in #2197
- [Flax DDPM] Make `key` optional so default pipelines don't fail by @pcuenca in #2176
- Show error when loading safety_checker `from_flax` by @pcuenca in #2187
- Fix k_dpm_2 & k_dpm_2_a on MPS by @psychedelicious in #2241
- Fix a typo: bfloa16 -> bfloat16 by @nickkolok in #2243
- Mention training problems with xFormers 0.0.16 by @pcuenca in #2254
- fix distributed init twice by @Fazziekey in #2252
- Fixes prompt input checks in StableDiffusion img2img pipeline by @jorgemcgomes in #2206
- Create convert_vae_pt_to_diffusers.py by @chavinlo in #2215
- Stable Diffusion Latent Upscaler by @yiyixuxu in #2059
- [Examples] Remove datasets important that is not needed by @patrickvonplaten in #2267
- Make center crop and random flip as args for unconditional image generation by @wfng92 in #2259
- [Tests] Fix slow tests by @patrickvonplaten in #2271
- Fix torchvision.transforms and transforms function naming clash by @wfng92 in #2274
- mps cross-attention hack: don't crash on fp16 by @pcuenca in #2258
- Use `accelerate` save & loading hooks to have better checkpoint structure by @patrickvonplaten in #2048
- Replace flake8 with ruff and update black by @patrickvonplaten in #2279
- Textual inv save log memory by @isamu-isozaki in #2184
- EMA: fix `state_dict()` and `load_state_dict()` & add `cur_decay_value` by @chenguolin in #2146
- [Examples] Test all examples on CPU by @patrickvonplaten in #2289
- fix pix2pix docs by @patrickvonplaten in #2290
- misc fixes by @williamberman in #2282
- Run same number of DDPM steps in inference as training by @bencevans in #2263
- [LoRA] Freezing the model weights by @erkams in #2245
- Fast CPU tests should also run on main by @patrickvonplaten in #2313
- Correct fast tests by @patrickvonplaten in #2314
- remove ddpm test_full_inference by @williamberman in #2291
- convert ckpt script docstring fixes by @williamberman in #2293
- [Community Pipeline] UnCLIP Text Interpolation Pipeline by @Abhinay1997 in #2257
- [Tests] Refactor push tests by @patrickvonplaten in #2329
- Add ethical guidelines by @giadilli in #2330
- Fix running LoRA with xformers by @bddppq in #2286
- Fix typo in load_pipeline_from_original_stable_diffusion_ckpt() method by @p1atdev in #2320
- [Docs] Fix ethical guidelines docs by @patrickvonplaten in #2333
- [Versatile Diffusion] Fix tests by @patrickvonplaten in #2336
- [Latent Upscaling] Remove unused noise by @patrickvonplaten in #2298
- [Tests] Remove unnecessary tests by @patrickvonplaten in #2337
- karlo image variation use kakaobrain upload by @williamberman in #2338
- github issue forum link by @williamberman in #2335
- dreambooth checkpointing tests and docs by @williamberman in #2339
- unet check length inputs by @williamberman in #2327
- unCLIP variant by @williamberman in #2297
- Log Unconditional Image Generation Samples to W&B by @bencevans in #2287
- Fix callback type hints - no optional function argument by @patrickvonplaten in #2357
- [Docs] initial docs about KarrasDiffusionSchedulers by @kashif in #2349
- KarrasDiffusionSchedulers type note by @williamberman in #2365
- [Tests] Add MPS skip decorator by @patrickvonplaten in #2362
- Funky spacing issue by @meg-huggingface in #2368
- schedulers add glide noising schedule by @williamberman in #2347
- add total number checkpoints to training scripts by @williamberman in #2367
- checkpointing_steps_total_limit->checkpoints_total_limit by @williamberman in #2374
- Fix 3-way merging with the checkpoint_merger community pipeline by @damian0815 in #2355
- [Variant] Add "variant" as input kwarg so to have better UX when downloading no_ema or fp16 weights by @patrickvonplaten in #2305
- [Pipelines] Adds pix2pix zero by @sayakpaul in #2334
- Add Self-Attention-Guided (SAG) Stable Diffusion pipeline by @SusungHong in #2193
- [SchedulingPNDM ] reset cur_model_output after each call by @patil-suraj in #2376
- train_text_to_image EMAModel saving by @williamberman in #2341
- [Utils] Adds `store()` and `restore()` methods to EMAModel by @sayakpaul in #2302
- `enable_model_cpu_offload` by @pcuenca in #2285
- add the UniPC scheduler by @wl-zhao in #2373
- Replace torch.concat calls by torch.cat by @fxmarty in #2378
- Make diffusers importable with transformers < 4.26 by @pcuenca in #2380
- [Examples] Make sure EMA works with any device by @patrickvonplaten in #2382
- [Dummy imports] Add missing if else statements for SD] by @patrickvonplaten in #2381
- Attend and excite 2 by @yiyixuxu in #2369
- [Pix2Pix0] Add utility function to get edit vector by @patrickvonplaten in #2383
- Revert "[Pix2Pix0] Add utility function to get edit vector" by @patrickvonplaten in #2384
- Fix stable diffusion onnx pipeline error when batch_size > 1 by @tianleiwu in #2366
- [Docs] Fix UniPC docs by @wl-zhao in #2386
- [Pix2Pix Zero] Fix slow tests by @sayakpaul in #2391
- [Pix2Pix] Add utility function by @patrickvonplaten in #2385
- Fix UniPC tests and remove some test warnings by @pcuenca in #2396
- [Pipelines] Add a section on generating captions and embeddings for Pix2Pix Zero by @sayakpaul in #2395
- Torch2.0 scaled_dot_product_attention processor by @patil-suraj in #2303
- add: inversion to pix2pix zero docs. by @sayakpaul in #2398
- Add semantic guidance pipeline by @manuelbrack in #2223
- Add ddim inversion pix2pix by @patrickvonplaten in #2397
- add MultiDiffusionPanorama pipeline by @omerbt in #2393
- Fixing typos in documentation by @anagri in #2389
- controlling generation docs by @williamberman in #2388
- apply_forward_hook simply returns if no accelerate by @daquexian in #2387
- Revert "Release: v0.13.0" by @williamberman in #2405
- controlling generation doc nits by @williamberman in #2406
- Fix typo in AttnProcessor2_0 symbol by @pcuenca in #2404
- add index page by @yiyixuxu in #2401
- add xformers 0.0.16 warning message by @williamberman in #2345
Significant community contributions
The following contributors have made significant changes to the library over the last release:
- @thedarkzeno
  - Create train_dreambooth_inpaint_lora.py (#2205)
- @prathikr
  - refactor onnxruntime integration (#2042)
- @Abhinay1997
  - [Community Pipeline] UnCLIP Text Interpolation Pipeline (#2257)
- @SusungHong
  - Add Self-Attention-Guided (SAG) Stable Diffusion pipeline (#2193)
- @wl-zhao
  - add the UniPC scheduler (#2373)
- @manuelbrack
  - Add semantic guidance pipeline (#2223)
- @omerbt
  - add MultiDiffusionPanorama pipeline (#2393)