Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add UniDiffuser model and pipeline #2963

Merged
merged 334 commits into from May 26, 2023
Merged

Conversation

dg845
Copy link
Contributor

@dg845 dg845 commented Apr 4, 2023

This PR implements a pipeline for the UniDiffuser model as discussed in #2857.

Model/Pipeline Description

The UniDiffuser model (paper, code) is a multi-modal model which extends the DDPM model to model all distributions relevant to a set of multi-modal data. From the paper abstract:

This paper proposes a unified diffusion framework (dubbed UniDiffuser) to fit all distributions relevant to a set of multi-modal data in one model. Our key insight is – learning diffusion models for marginal, conditional, and joint distributions can be unified as predicting the noise in the perturbed data, where the perturbation levels (i.e. timesteps) can be different for different modalities. Inspired by the unified view, UniDiffuser learns all distributions simultaneously with a minimal modification to the original diffusion model – perturbs data in all modalities instead of a single modality, inputs individual timesteps in different modalities, and predicts the noise of all modalities instead of a single modality. UniDiffuser is parameterized by a transformer for diffusion models to handle input types of different modalities. Implemented on large-scale paired image-text data, UniDiffuser is able to perform image, text, text-to-image, image-to-text, and image-text pair generation by setting proper timesteps without additional overhead...

In this PR, we implement a image-text UniDiffuser model as described in the paper:

image

Usage Examples

import requests
import torch
from PIL import Image
from io import BytesIO

from diffusers import UniDiffuserPipeline

device = "cuda"
model_id_or_path = "dg845/unidiffuser-diffusers"
pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe.to(device)

# Joint image-text generation. The generation task is automatically inferred.
sample = pipe(num_inference_steps=20, guidance_scale=8.0)
image = sample.images[0]
text = sample.text[0]
image.save("unidiffuser_sample_joint_image.png")
print(text)

# The mode can be set manually. The following is equivalent to the above:
pipe.set_joint_mode()
sample2 = pipe(num_inference_steps=20, guidance_scale=8.0)

# Note that if you set the mode manually the pipeline will no longer attempt
# to automatically infer the mode. You can re-enable this with reset_mode().
pipe.reset_mode()

# Text-to-image generation.
prompt = "an elephant under the sea"

sample = pipe(prompt=prompt, num_inference_steps=20, guidance_scale=8.0)
t2i_image = sample.images[0]
t2i_image.save("unidiffuser_sample_text2img_image.png")

# Image-to-text generation.
image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg"
response = requests.get(image_url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))

sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0)
i2t_text = sample.text[0]
print(i2t_text)

# Image variation can be performed with a image-to-text generation followed by a text-to-image generation:
sample = pipe(prompt=i2t_text, num_inference_steps=20, guidance_scale=8.0)
final_image = sample.images[0]
final_image.save("unidiffuser_image_variation_sample.png")

# Text variation can be performed with a text-to-image generation followed by a image-to-text generation:
sample = pipe(image=t2i_image, num_inference_steps=20, guidance_scale=8.0)
final_prompt = sample.text[0]
print(final_prompt)

TODO

  • Implement UniDiffuserModel [U-ViT (paper, code)]
  • Implement UniDiffuserPipeline
  • Script to convert UniDiffuser checkpoints to diffusers checkpoints
  • Upload pre-trained UniDiffuser model [see this comment for more details]
  • Create tests for UniDiffuserPipeline
  • Create documentation for UniDiffuserPipeline
  • Add docstrings for model and pipeline
  • Add usage example(s)

Discussion

  • (TBD)

CC

@patrickvonplaten
@nemonameless
@baofff (author on original paper, author of original code)

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Apr 4, 2023

The documentation is not available anymore as the PR was closed or merged.

@dg845
Copy link
Contributor Author

dg845 commented Apr 4, 2023

Currently, the code in the PR isn't in a working state, and I haven't implemented tests or tested the code yet. I've opened the PR because I wanted to get some preliminary feedback on the design and code. In particular, I have the following questions:


Design Questions:

  1. Since the image-text UniDiffuser model is capable of doing marginal text or image generation, conditional text-to-image and image-to-text generation, and joint image-text generation, I've currently implemented the __call__ method to have a mode parameter that allows the user to generate text, images, text-conditioned images, etc. as desired. I'm not sure if this fits in with the pipeline design philosophy, particularly the principle that

Every pipeline should have one and only one way to run it via a __call__ method.

Would it be better if I split UniDiffuserPipeline into separate pipelines for each generation task: e.g. UniDiffuserTextToImagePipeline, etc., akin to VersatileDiffusionTextToImagePipeline, etc.?

  1. In particular, I would greatly appreciate some preliminary feedback on the main implemented classes: UniDiffuserPipeline, UniDiffuserModel, and UniDiffuserTextDecoder.

Questions about Tests:

  1. Is there a guide to writing tests for the diffusers library?
    1. Partial answer: I've found the transformers testing guide to be useful, and I think most of the stuff there is applicable to diffusers as well.
  2. Is there an easy way to find small model checkpoints for testing, such as analogues to CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")?
    1. The transformers testing guide suggests something like grep -r "tiny" tests/ examples/ to find examples of tiny models/pipelines/etc. for testing.
    2. I think the hf-internal-testing hub page should also list all such models.

[if there is a better place to move this discussion, please let me know :) ]

def __init__(
self,
tokenizer: GPT2Tokenizer,
text_decoder: GPT2LMHeadModel,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we try to seperate the tokenizer and text decoder here.

diffusers should be able to load the tokenizer out of the box, you just have to define it in the pipeline, e.g. here: https://github.com/huggingface/diffusers/pull/2963/files#r1159526676

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also we cannot pass the text_decoder here at init as this would prevent us to be able to use from_pretained(...) of the model class. Could you maybe try to follow the design as done here:

class SpectrogramContEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):

See how we import blocks from transformers to design our own new model. I think here you could just do the following:

def __init__(
    self,
    num_layers=12,
    ...
):
    config = GPT2Config(...take all the config params from init)
    self.text_decoder = GPT2LMHeadModel(config)

We then design a new checkpoint architecture for the UniDiffusersTextDecoder and upload pretrained weights for it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the design of __init__ following the example: removed tokenizer and text_decoder args, added GPT2 config args.

eos = "<|EOS|>"
special_tokens_dict = {"eos_token": eos}
self.tokenizer = tokenizer
self.tokenizer.add_special_tokens(special_tokens_dict)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that we can do this directly for the uploaded tokenizer. E.g. let's just upload a tokenizer that has EOS already added so that we don't have to do it every time we call the model at init

More than happy to help here later on!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the tokenizer logic from __init__, will work on uploading the appropriate tokenizer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've prepared some native diffusers checkpoints for the current implementation of the UniDiffuserPipeline and its building blocks (e.g. UniDiffuserModel, UniDiffuserTextDecoder, etc.) [see the convert_to_ckpt.py script]. How can I upload these up to the hub?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was able to upload some models to the hub (see e.g. small test models here), but I'm confused about how to save/push to hub a tokenizer with added special tokens. The documentation for PreTrainedTokenizerBase.from_pretrained says that it won't save modifications to the tokenizer after initialization and I wasn't able to find any resources on how to do it after searching.

For reference, the code in the base unidiffuser library is something like

eos = '<|EOS|>'
special_tokens_dict = {'eos_token': eos}
base_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
base_tokenizer.add_special_tokens(special_tokens_dict)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding uploading the weights and the new tokenizer, you can can call push_to_hub() directly on the model.

So, for example (considering UniDiffuserModel is already populated with the pre-trained checkpoints):

unidiffusers = UniDiffuserModel(...)
unidiffusers.push_to_hub("your_hub_user_name/model_id")

Same applies for the rest of the models and the tokenizer.

self.transformer = text_decoder
# TODO: need to set the eos_token_id correctly
self.transformer.config.eos_token_id = self.tokenizer.eos_token_id
self.transformer.resize_token_embeddings(len(self.tokenizer))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also make sure that the GPT2Transformer has the correct number of word embeddings before loading it so that we don't have to always resize the embedding every time at init

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1. I would prefer to have the decoder with the rejigged embeddings on the Hub rather than rejigging on the fly.

return generated_captions

@torch.no_grad()
def generate_beam(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

works for me!

"""

@register_to_config
def __init__(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The design here looks good to me! Note that I think we can remove some redundant code that is not needed for this use case. I think you only need one of the three cases:

        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = in_channels is not None and patch_size is not None

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should have removed most of the redundant code (kept only the code handling the patch input case, since that's what the original UniDiffuser implementation used).

@patrickvonplaten
Copy link
Contributor

Great first design! I left some comments directly in the code. In short I think the general design is very nice - the models should be defined under the pipeline folder just like you do and the pipeline also looks quite nice already.

Answering your questions in line

Currently, the code in the PR isn't in a working state, and I haven't implemented tests or tested the code yet. I've opened the PR because I wanted to get some preliminary feedback on the design and code. In particular, I have the following questions:

Design Questions:

  1. Since the image-text UniDiffuser model is capable of doing marginal text or image generation, conditional text-to-image and image-to-text generation, and joint image-text generation, I've currently implemented the __call__ method to have a mode parameter that allows the user to generate text, images, text-conditioned images, etc. as desired. I'm not sure if this fits in with the pipeline design philosophy, particularly the principle that

Every pipeline should have one and only one way to run it via a __call__ method.

Would it be better if I split UniDiffuserPipeline into separate pipelines for each generation task: e.g. UniDiffuserTextToImagePipeline, etc., akin to VersatileDiffusionTextToImagePipeline, etc.?

I think since the purpose of UniDiffusers is exactly to bring all modes into the same distribution, one pipeline is nice here. So this design works for me. I'd maybe just not have a "mode" call input, but instead automatically decide the mode depending on what the user puts in. E.g. if the user just passes a "text" input, we're in text2img mode, if just a "image" input, we're in image to text mode => would this design work or are the inputs not enough to define which mode one is in? E.g. are muiltple modes possible for the same input combination?

  1. In particular, I would greatly appreciate some preliminary feedback on the main implemented classes: UniDiffuserPipeline, UniDiffuserModel, and UniDiffuserTextDecoder.

Left comments mostly directly in the code. In short:
UniDiffuserPipeline - looks good already, just:

  • let's make the tokenizer directly an input
  • if possible remove the mode input or if we can't make it maybe a setter variable pipe.set_text_to_image()

UniDiffuserModel - looks good, let's just remove all code that we don't need

UniDiffuserTextDecoder - here we need to change the init design a bit so that it would work flawlessly with from_pretrained(...) e.g. we can have models such as gpt2lmhead in the init (left some comments diretly in the code)

Questions about Tests:

  1. Is there a guide to writing tests for the diffusers library?

    1. Partial answer: I've found the transformers testing guide to be useful, and I think most of the stuff there is applicable to diffusers as well.
  2. Is there an easy way to find small model checkpoints for testing, such as analogues to CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")?

Not really. Some guides that could help:

  1. The transformers testing guide suggests something like grep -r "tiny" tests/ examples/ to find examples of tiny models/pipelines/etc. for testing.
  2. I think the hf-internal-testing hub page should also list all such models.

Regarding tiny models, yeah we just create them ourselves. What you can do here is to just load tiny configs to create random tiny models and use those for faster testing :-)

[if there is a better place to move this discussion, please let me know :) ]

Hope this helps a bit so that you can move forward, let me know if you need more help!

@dg845
Copy link
Contributor Author

dg845 commented Apr 7, 2023

Thanks for the review! With regards to this:

E.g. if the user just passes a "text" input, we're in text2img mode, if just a "image" input, we're in image to text mode => would this design work or are the inputs not enough to define which mode one is in? E.g. are muiltple modes possible for the same input combination?

for the currently supported modes, there is some ambiguity when neither text nor image input is provided. In this case, we cannot be sure whether the user wants unconditional ("marginal") image generation, unconditional ("marginal") text generation, or joint image-text generation.

The original code additionally supports image variation ("img2text2img") and text variation ("text2img2text") modes, whose inputs would be the same as the image-to-text (a conditioning image) and text-to-image (a conditioning prompt) modes, respectively. So supporting these modes would also cause some ambiguity.

So perhaps we could infer the mode in __call__, with e.g. only text input defaulting to the text2img mode and providing neither text nor image input defaulting to img mode. We would also provide setter variables pipe.set_text_to_image(), pipe.set_text_variation(), etc. and if the user uses the setter variables, we would respect that mode over the inferred mode.

[Just as a side note, the image variation implementation is different between StableDiffusionImageVariationPipeline and UniDiffuser. The Stable Diffusion image variation model is essentially a Stable Diffusion model trained with the text encoder swapped out for a CLIP-based image encoder (see here), but UniDiffuser uses its ability to do both image-conditioned text generation and text-conditioned image generation to do a "round-trip translation" of the image into text and back to an image. Not sure this is relevant to the discussion, just something I found interesting :). ]

[Edit: pushed new commit with possible implementation as described above]

@nemonameless
Copy link

I have referenced some codes of yours and combined with mine, and also submited an initial version PR PaddlePaddle/PaddleNLP#5487 , hope to learn from each other and contribute to the community

@dg845
Copy link
Contributor Author

dg845 commented Apr 8, 2023

Hi @patrickvonplaten and @baofff,

In looking at the noise prediction model architecture, I'm using BasicTransformerBlock as my transformer block, which I've noticed has two main differences as compared to the Block implementation in the original code:

  1. BasicTransformerBlock is pre-LayerNorm, while Block is post-LayerNorm.
  2. Block has the LayerNorms on the residual backbone of the block, whereas BasicTransformerBlock does not. (That is, in Block, the layer norm is applied after the skip connections, rather than before.)

In light of this, I have the following questions:

  1. Should we expect a big difference in performance between the two implementations for inference? (The paper reports that using a pre-LayerNorm transformer is numerically unstable when training a UniDiffuser model.)
  2. Should I follow the original implementation as closely as possible?

@dg845
Copy link
Contributor Author

dg845 commented Apr 15, 2023

As a note, if you want to look at the code I used to calculate the expected_slices for the fast default tests, you can look at https://github.com/dg845/unidiffuser/blob/test_sampling/sample_test_v1.py.

@patrickvonplaten
Copy link
Contributor

Very cool! This looks like almost ready to be merged to me - thanks a lot for re-iterating on the design :-)

@patrickvonplaten
Copy link
Contributor

@williamberman @sayakpaul when you have a moment, it'd be super cool if you could review

stop_token: str = "<|EOS|>",
):
"""
Generates text using the given tokenizer and text prompt or token embedding via beam search.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Help me understand this a bit. Why would there be a need to generate text from a given text prompt?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I think I wrote this docstring in a confusing way. In the context of UniDiffuser sampling, we use this function to generate output text (when appropriate) from the text latents after we process the CLIP-embedded input prompt using the unet (UniDiffuserModel) model. The method accepts both prompt and embed arguments, for input tokens and embeddings respectively, but we only ever call it with input embeddings (as described above):

generated_captions.append(self.generate_beam(tokenizer, embed=feature, device=device)[0])

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh okay. Yeah, then I guess we need to it make it a bit clearer from the code?

labels (`torch.Tensor`, *optional*):
TODO
"""
embedding_text = self.transformer.transformer.wte(tokens)
Copy link
Member

@sayakpaul sayakpaul Apr 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why transformer.transformer?

Copy link
Contributor Author

@dg845 dg845 Apr 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took the forward(...) method from the original code. Upon reviewing the code, I think the forward method was probably intended to do the following: the tokens argument is a sequence of input vocab token IDs for the GPT2LMHeadModel, while the prefix argument is the hidden state of another model (e.g. something like transformers.modeling_outputs.BaseModelOutputWithPooling.last_hidden_state of a CLIPTextModel). prefix then gets converted to an intermediate representation via self.encode_prefix(...) and then converted into the latent space of the GPT model via self.decode_prefix(...) (if they are being used). We then combine the embedding of tokens with the prefix embedding and then do a forward pass of the internal GPT2LMHeadModel.

I guess it's confusing currently because on lines 52-54 instead of using n_embd as the input dimension to nn.Linear we should instead have a new argument prefix_inner_dim and use that, e.g.

self.encode_prefix = (
        nn.Linear(prefix_inner_dim, self.prefix_hidden_dim) if self.prefix_hidden_dim is not None else nn.Identity()
)

Furthermore, prefix_hidden_dim should probably always need to be supplied, since prefix_inner_dim and n_embd are in general not guaranteed to be the same.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for explaining!

I guess we will know better once we start testing the code.

@sayakpaul
Copy link
Member

Design looks quite clean and matured to me. My questions / comments are very minor ones and are probably already covered by @patrickvonplaten's comments.

I think since the purpose of UniDiffusers is exactly to bring all modes into the same distribution, one pipeline is nice here. So this design works for me. I'd maybe just not have a "mode" call input, but instead automatically decide the mode depending on what the user puts in. E.g. if the user just passes a "text" input, we're in text2img mode, if just a "image" input, we're in image to text mode => would this design work or are the inputs not enough to define which mode one is in? E.g. are muiltple modes possible for the same input combination?

+1 to this.

@dg845
Copy link
Contributor Author

dg845 commented Apr 25, 2023

I've uploaded a diffusers version of the unidiffuser-v1 checkpoint at https://huggingface.co/dg845/unidiffuser-diffusers and a small random testing pipeline at https://huggingface.co/dg845/unidiffuser-diffusers-test. Note that the text_encoder is from openai/clip-vit-large-patch14 and the image_encoder and image_processor are from openai/clip-vit-base-patch32, which should match the frozen CLIP encoders used by the original implementation. text_tokenizer should have the new EOS token added, and text_decoder should have its embeddings appropriately resized for the new token.

I've also opened a PR at hf-internal-testing/diffusers-images in the hub to add an example image for UniDiffuserPipeline testing.

@sayakpaul
Copy link
Member

This is great! Thanks so much for your efforts. I think now the TODOs are:

I have also merged your PR. So, hopefully, this unblocks you. @patrickvonplaten can help us with the repo transfers.

@patrickvonplaten
Copy link
Contributor

Let me know once you need help with a model transfer

ernestchu and others added 4 commits May 5, 2023 07:22
* Fix a bug of pano when not doing CFG

* enhance code quality

* apply formatting.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* fix progress bar issue in pipeline_text_to_video_zero.py. Copy scheduler after first backward

* fix tensor loading in test_text_to_video_zero.py

* make style && make quality
* fix: norm group test for UNet3D.

* chore: speed up the panorama tests (fast).

* set default value of _test_inference_batch_single_identical.

* fix: batch_sizes default value.
Comment on lines 34 to 36
| Pipeline | Tasks | Demo
|---|---|:---:|
| [UniDiffuserPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_unidiffuser.py) | *Joint Image-Text Gen*, *Text-to-Image*, *Image-to-Text*, *Image Gen*, *Text Gen*, *Image Variation*, *Text Variation* | |
Copy link
Member

@sayakpaul sayakpaul May 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
| Pipeline | Tasks | Demo
|---|---|:---:|
| [UniDiffuserPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_unidiffuser.py) | *Joint Image-Text Gen*, *Text-to-Image*, *Image-to-Text*, *Image Gen*, *Text Gen*, *Image Variation*, *Text Variation* | |
| Pipeline | Tasks | Demo | Colab |
|---|---|:---:|
| [UniDiffuserPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_unidiffuser.py) | *Joint Image-Text Gen*, *Text-to-Image*, *Image-to-Text*, *Image Gen*, *Text Gen*, *Image Variation*, *Text Variation* | [🤗 Spaces](https://huggingface.co/spaces/thu-ml/unidiffuser) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/unidiffuser.ipynb) |

For now, let's add a link to the original demo. @hysts is working on to change the demo to have diffusers usage.

Copy link
Member

@sayakpaul sayakpaul May 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prepared a Colab Notebook from your awesome documentation: huggingface/notebooks#377

Also prepared this GIF to showcase the powerfulness of the pipeline:
unidiffuser

Comment on lines 120 to 140
import requests
import torch
from PIL import Image
from io import BytesIO

from diffusers import UniDiffuserPipeline

device = "cuda"
model_id_or_path = "thu-ml/unidiffuser-v1"
pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe.to(device)

# Image-to-text generation
image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg"
response = requests.get(image_url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))

sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0)
i2t_text = sample.text[0]
print(text)
Copy link
Member

@sayakpaul sayakpaul May 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import requests
import torch
from PIL import Image
from io import BytesIO
from diffusers import UniDiffuserPipeline
device = "cuda"
model_id_or_path = "thu-ml/unidiffuser-v1"
pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe.to(device)
# Image-to-text generation
image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg"
response = requests.get(image_url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0)
i2t_text = sample.text[0]
print(text)
from diffusers import UniDiffuserPipeline
from diffusers.utils import load_image
device = "cuda"
model_id_or_path = "thu-ml/unidiffuser-v1"
pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe.to(device)
# Image-to-text generation
image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg"
init_image = load_image(image_url).resize((512, 512))
sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0)
i2t_text = sample.text[0]
print(i2t_text)

Reduces the LoC :)

Comment on lines 163 to 168
# Image variation can be performed with a image-to-text generation followed by a text-to-image generation:
# 1. Image-to-text generation
image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg"
response = requests.get(image_url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we can follow the same as https://github.com/huggingface/diffusers/pull/2963/files#r1205061596 for loading and resizing the image?

- all
- __call__

## ImageTextPipelineOutput
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also wanted to know if there's any argument control the number of images / text I wanted to generate as a part of the variation mode.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can control the num_images_per_prompt in the text-to-image mode, so that's settled. But what about text variation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For modes which generate only text (img2text and text), there's an analogous num_prompts_per_image argument to __call__ . So when you perform the second img2text generation for text variation you can specify num_prompts_per_image > 1 to get multiple text variation samples.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should go here:

https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/outputs.mdx

It feels more natural to me to have the documentation for ImageTextPipelineOutput alongside ImagePipelineOutput, which is at the Diffusion Pipeline doc page.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've gone ahead and moved the ImageTextPipelineOutput documentation to /api/diffusion_pipeline.mdx (alongside the ImagePipelineOutput and AudioPipelineOutput documentation). Let me know if it would be better somewhere else (for example, at /api/outputs.mdx as originally suggested) :).

Copy link
Member

@sayakpaul sayakpaul May 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand. But we will soon update that too :)

Cc: @patrickvonplaten

Copy link
Contributor Author

@dg845 dg845 May 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, so would it be better if I move it to /api/outputs? Or is it fine to leave it at /api/diffusion_pipeline for now?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep it as is for now. Then we will bulk move things :)


sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0)
i2t_text = sample.text[0]
print(text)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: should be i2t_text.


### Unconditional Image and Text Generation

Unconditional generation (where we start from only latents sampled from a standard Gaussian prior) from a `UniDiffuserPipeline` will produce a (image, text) pair:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Unconditional generation (where we start from only latents sampled from a standard Gaussian prior) from a `UniDiffuserPipeline` will produce a (image, text) pair:
Unconditional generation (where we start from only latents sampled from a standard Gaussian prior) from a [`UniDiffuserPipeline`] will produce a (image, text) pair:

So, that the hyperlink is automatically rendered.

print(text)
```

The `img2text` mode requires that an input `image` be supplied. You can set the `img2text` mode manually with [`UniDiffuser.set_image_to_text_mode`].
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The `img2text` mode requires that an input `image` be supplied. You can set the `img2text` mode manually with [`UniDiffuser.set_image_to_text_mode`].
The `img2text` mode requires that an input `image` be supplied. You can set the `img2text` mode manually with [`UniDiffuserPipeline.set_image_to_text_mode`].

self.transformer = GPT2LMHeadModel(gpt_config)

def forward(
self,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

return self.encode_prefix(prefix)

@torch.no_grad()
def generate_captions(self, features, eos_token_id, device):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

Comment on lines 200 to 206
eos_token_id: Optional[int] = None,
input_ids=None,
input_embeds=None,
device=None,
beam_size: int = 5,
entry_length: int = 67,
temperature: float = 1.0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
eos_token_id: Optional[int] = None,
input_ids=None,
input_embeds=None,
device=None,
beam_size: int = 5,
entry_length: int = 67,
temperature: float = 1.0,
input_ids=None,
input_embeds=None,
device=None,
beam_size: int = 5,
entry_length: int = 67,
temperature: float = 1.0,
eos_token_id: Optional[int] = None,

(nit) Let's change the order here maybe since the eos_token_id should probably not be the first input

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed :).

cross_attention_kwargs=None,
class_labels=None,
):
# Pre-LayerNorm
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great this PR looks good to go for me! Just one final tiny nit regarding the ordering of the generate input.

Apart from this, this is good to merge from my side :-) Incredible work here @dg845! This is really a difficult model with many components and the final implementation is super nice :-)

@sayakpaul
Copy link
Member

@dg845 once the conflicts are resolved and tests pass, we will merge :)

Meanwhile, I will also correct the gif.

Really amazing contribution. I hope the contribution experience was enjoyable for you.

@sayakpaul
Copy link
Member

@patrickvonplaten a friendly ping for these transfers:

#2963 (comment)

@dg845
Copy link
Contributor Author

dg845 commented May 26, 2023

Thanks! I really enjoyed working on this PR :). And thanks for all the advice and help along the way :).

@patrickvonplaten
Copy link
Contributor

@sayakpaul feel free to merge whenever! All good from my side

@sayakpaul sayakpaul merged commit 352ca31 into huggingface:main May 26, 2023
7 checks passed
@sayakpaul
Copy link
Member

@dg845 thanks again for your amazing contribution. The pipeline and the components are now live at: https://huggingface.co/docs/diffusers/main/en/api/pipelines/unidiffuser

@dg845 dg845 deleted the unidiffuser-pipeline branch May 27, 2023 02:09
@patrickvonplaten patrickvonplaten changed the title [WIP] Add UniDiffuser model and pipeline Add UniDiffuser model and pipeline May 30, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Fix a bug of pano when not doing CFG (#3030)

* Fix a bug of pano when not doing CFG

* enhance code quality

* apply formatting.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Text2video zero refinements (#3070)

* fix progress bar issue in pipeline_text_to_video_zero.py. Copy scheduler after first backward

* fix tensor loading in test_text_to_video_zero.py

* make style && make quality

* Release: v0.15.0

* [Tests] Speed up panorama tests (#3067)

* fix: norm group test for UNet3D.

* chore: speed up the panorama tests (fast).

* set default value of _test_inference_batch_single_identical.

* fix: batch_sizes default value.

* [Post release] v0.16.0dev (#3072)

* Adds profiling flags, computes train metrics average. (#3053)

* WIP controlnet training

- bugfix --streaming
- bugfix running report_to!='wandb'
- adds memory profile before validation

* Adds final logging statement.

* Sets train epochs to 11.

Looking at a longer ~16ep run, we see only good validation images
after ~11ep:

https://wandb.ai/andsteing/controlnet_fill50k/runs/3j2hx6n8

* Removes --logging_dir (it's not used).

* Adds --profile flags.

* Updates --output_dir=runs/fill-circle-{timestamp}.

* Compute mean of `train_metrics`.

Previously `train_metrics[-1]` was logged, resulting in very bumpy train
metrics.

* Improves logging a bit.

- adds l2_grads gradient norm logging
- adds steps_per_sec
- sets walltime as x coordinate of train/step
- logs controlnet_params config

* Adds --ccache (doesn't really help though).

* minor fix in controlnet flax example (#2986)

* fix the error when push_to_hub but not log validation

* contronet_from_pt & controlnet_revision

* add intermediate checkpointing to the guide

* Bugfix --profile_steps

* Sets `RACKER_PROJECT_NAME='controlnet_fill50k'`.

* Logs fractional epoch.

* Adds relative `walltime` metric.

* Adds `StepTraceAnnotation` and uses `global_step` insetad of `step`.

* Applied `black`.

* Streamlines commands in README a bit.

* Removes `--ccache`.

This makes only a very small difference (~1 min) with this model size, so removing
the option introduced in cdb3cc.

* Re-ran `black`.

* Update examples/controlnet/README.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Converts spaces to tab.

* Removes repeated args.

* Skips first step (compilation) in profiling

* Updates README with profiling instructions.

* Unifies tabs/spaces in README.

* Re-ran style & quality.

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* [Pipelines] Make sure that None functions are correctly not saved (#3080)

* doc string example remove from_pt (#3083)

* [Tests] parallelize (#3078)

* [Tests] parallelize

* finish folder structuring

* Parallelize tests more

* Correct saving of pipelines

* make sure logging level is correct

* try again

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Throw deprecation warning for return_cached_folder (#3092)

Throw deprecation warning

* Allow SD attend and excite pipeline to work with any size output images (#2835)

Allow stable diffusion attend and excite pipeline to work with any size output image. Re: #2476, #2603

* [docs] Update community pipeline docs (#2989)

* update community pipeline docs

* fix formatting

* explain sharing workflows

* Add to support Guess Mode for StableDiffusionControlnetPipleline (#2998)

* add guess mode (WIP)

* fix uncond/cond order

* support guidance_scale=1.0 and batch != 1

* remove magic coeff

* add docstring

* add intergration test

* add document to controlnet.mdx

* made the comments a bit more explanatory

* fix table

* fix default value for attend-and-excite (#3099)

* fix default

* remvoe one line as requested by gc team  (#3077)

remvoe one line

* ddpm custom timesteps (#3007)

add custom timesteps test

add custom timesteps descending order check

docs

timesteps -> custom_timesteps

can only pass one of num_inference_steps and timesteps

* Fix breaking change in `pipeline_stable_diffusion_controlnet.py` (#3118)

fix breaking change

* Add global pooling to controlnet (#3121)

* [Bug fix] Fix img2img processor with safety checker (#3127)

Fix img2img processor with safety checker

* [Bug fix] Make sure correct timesteps are chosen for img2img (#3128)

Make sure correct timesteps are chosen for img2img

* Improve deprecation warnings (#3131)

* Fix config deprecation (#3129)

* Better deprecation message

* Better deprecation message

* Better doc string

* Fixes

* fix more

* fix more

* Improve __getattr__

* correct more

* fix more

* fix

* Improve more

* more improvements

* fix more

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* make style

* Fix all rest & add tests & remove old deprecation fns

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* feat: verfication of multi-gpu support for select examples. (#3126)

* feat: verfication of multi-gpu support for select examples.

* add: multi-gpu training sections to the relvant doc pages.

* speed up attend-and-excite fast tests (#3079)

* Optimize log_validation in train_controlnet_flax (#3110)

extract pipeline from log_validation

* make style

* Correct textual inversion readme (#3145)

* Update README.md

* Apply suggestions from code review

* Add unet act fn to other model components (#3136)

Adding act fn config to the unet timestep class embedding and conv
activation.

The custom activation defaults to silu which is the default
activation function for both the conv act and the timestep class
embeddings so default behavior is not changed.

The only unet which use the custom activation is the stable diffusion
latent upscaler https://huggingface.co/stabilityai/sd-x2-latent-upscaler/blob/main/unet/config.json
(I ran a script against the hub to confirm).
The latent upscaler does not use the conv activation nor the timestep
class embeddings so we don't change its behavior.

* class labels timestep embeddings projection dtype cast (#3137)

This mimics the dtype cast for the standard time embeddings

* [ckpt loader] Allow loading the Inpaint and Img2Img pipelines, while loading a ckpt model (#2705)

* [ckpt loader] Allow loading the Inpaint and Img2Img pipelines, while loading a ckpt model

* Address review comment from PR

* PyLint formatting

* Some more pylint fixes, unrelated to our change

* Another pylint fix

* Styling fix

* add from_ckpt method as Mixin (#2318)

* add mixin class for pipeline from original sd ckpt

* Improve

* make style

* merge main into

* Improve more

* fix more

* up

* Apply suggestions from code review

* finish docs

* rename

* make style

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add TensorRT SD/txt2img Community Pipeline to diffusers along with TensorRT utils (#2974)

* Add SD/txt2img Community Pipeline to diffusers along with TensorRT utils

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* update installation command

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* update tensorrt installation

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* changes
1. Update setting of cache directory
2. Address comments: merge utils and pipeline code.
3. Address comments: Add section in README

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* apply make style

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

---------

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Correct `Transformer2DModel.forward` docstring (#3074)

⚙️chore(transformer_2d) update function signature for encoder_hidden_states

* Update pipeline_stable_diffusion_inpaint_legacy.py (#2903)

* Update pipeline_stable_diffusion_inpaint_legacy.py

* fix preprocessing of Pil images with adequate batch size

* revert map

* add tests

* reformat

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* next try to fix the style

* wth is this

* Update testing_utils.py

* Update testing_utils.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Modified altdiffusion pipline to support altdiffusion-m18 (#2993)

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

---------

Co-authored-by: root <fulong_ye@163.com>

* controlnet training resize inputs to multiple of 8 (#3135)

controlnet training center crop input images to multiple of 8

The pipeline code resizes inputs to multiples of 8.
Not doing this resizing in the training script is causing
the encoded image to have different height/width dimensions
than the encoded conditioning image (which uses a separate
encoder that's part of the controlnet model).

We resize and center crop the inputs to make sure they're the
same size (as well as all other images in the batch). We also
check that the initial resolution is a multiple of 8.

* adding custom diffusion training to diffusers examples (#3031)

* diffusers==0.14.0 update

* custom diffusion update

* custom diffusion update

* custom diffusion update

* custom diffusion update

* custom diffusion update

* custom diffusion update

* custom diffusion

* custom diffusion

* custom diffusion

* custom diffusion

* custom diffusion

* apply formatting and get rid of bare except.

* refactor readme and other minor changes.

* misc refactor.

* fix: repo_id issue and loaders logging bug.

* fix: save_model_card.

* fix: save_model_card.

* fix: save_model_card.

* add: doc entry.

* refactor doc,.

* custom diffusion

* custom diffusion

* custom diffusion

* apply style.

* remove tralining whitespace.

* fix: toctree entry.

* remove unnecessary print.

* custom diffusion

* custom diffusion

* custom diffusion test

* custom diffusion xformer update

* custom diffusion xformer update

* custom diffusion xformer update

---------

Co-authored-by: Nupur Kumari <nupurkumari@Nupurs-MacBook-Pro.local>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Nupur Kumari <nupurkumari@nupurs-mbp.wifi.local.cmu.edu>

* make style

* Update custom_diffusion.mdx (#3165)

Add missing newlines for rendering the links correctly

* Added distillation for quantization example on textual inversion. (#2760)

* Added distillation for quantization example on textual inversion.

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

* refined readme and code style.

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

* Update text2images.py

* refined code of model load and added compatibility check.

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

* fixed code style.

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

* fix C403 [*] Unnecessary `list` comprehension (rewrite as a `set` comprehension)

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

---------

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

* Update Noise Autocorrelation Loss Function for Pix2PixZero Pipeline (#2942)

* Update Pix2PixZero Auto-correlation Loss

* Add fast inversion tests

* Clarify purpose and mark as deprecated

Fix inversion prompt broadcasting

* Register modules set to `None` in config for `test_save_load_optional_components`

* Update new tests to coordinate with #2953

* [DreamBooth] add text encoder LoRA support in the DreamBooth training script (#3130)

* add: LoRA text encoder support for DreamBooth example.

* fix initialization.

* fix: modification call.

* add: entry in the readme.

* use dog dataset from hub.

* fix: params to clip.

* add entry to the LoRA doc.

* add: tests for lora.

* remove unnecessary list comprehension./

* Update Habana Gaudi documentation (#3169)

* Update Habana Gaudi doc

* Fix tables

* Add model offload to x4 upscaler (#3187)

* Add model offload to x4 upscaler

* fix

* [docs] Deterministic algorithms (#3172)

deterministic algos

* Update custom_diffusion.mdx to credit the author (#3163)

* Update custom_diffusion.mdx

* fix: unnecessary list comprehension.

* Fix TensorRT community pipeline device set function (#3157)

pass silence_dtype_warnings as kwarg

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make `from_flax` work for controlnet (#3161)

fix from_flax

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [docs] Clarify training args (#3146)

* clarify training arg

* apply feedback

* Multi Vector Textual Inversion (#3144)

* Multi Vector

* Improve

* fix multi token

* improve test

* make style

* Update examples/test_examples.py

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* update

* Finish

* Apply suggestions from code review

---------

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Add `Karras sigmas` to HeunDiscreteScheduler (#3160)

* Add karras pattern to discrete heun scheduler

* Add integration test

* Fix failing CI on pytorch test on M1 (mps)

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [AudioLDM] Fix dtype of returned waveform (#3189)

* Fix bug in train_dreambooth_lora (#3183)

* Update train_dreambooth_lora.py

fix bug

* Update train_dreambooth_lora.py

* [Community Pipelines] Update lpw_stable_diffusion pipeline (#3197)

* Update lpw_stable_diffusion.py

* fix cpu offload

* Make sure VAE attention works with Torch 2_0 (#3200)

* Make sure attention works with Torch 2_0

* make style

* Fix more

* Revert "[Community Pipelines] Update lpw_stable_diffusion pipeline" (#3201)

Revert "[Community Pipelines] Update lpw_stable_diffusion pipeline (#3197)"

This reverts commit 9965cb50eac12e397473f01535aab43aae76b4ab.

* [Bug fix] Fix batch size attention head size mismatch (#3214)

* fix mixed precision training on train_dreambooth_inpaint_lora (#3138)

cast to weight dtype

* adding enable_vae_tiling and disable_vae_tiling functions (#3225)

adding enable_vae_tiling and disable_val_tiling functions

* Add ControlNet v1.1 docs (#3226)

Add v1.1 docs

* Fix issue in maybe_convert_prompt (#3188)

When the token used for textual inversion does not have any special symbols (e.g. it is not surrounded by <>), the tokenizer does not properly split the replacement tokens.  Adding a space for the padding tokens fixes this.

* Sync cache version check from transformers (#3179)

sync cache version check from transformers

* Fix docs text inversion (#3166)

* Fix docs text inversion

* Apply suggestions from code review

* add model (#3230)

* add

* clean

* up

* clean up more

* fix more tests

* Improve docs further

* improve

* more fixes docs

* Improve docs more

* Update src/diffusers/models/unet_2d_condition.py

* fix

* up

* update doc links

* make fix-copies

* add safety checker and watermarker to stage 3 doc page code snippets

* speed optimizations docs

* memory optimization docs

* make style

* add watermarking snippets to doc string examples

* make style

* use pt_to_pil helper functions in doc strings

* skip mps tests

* Improve safety

* make style

* new logic

* fix

* fix bad onnx design

* make new stable diffusion upscale pipeline model arguments optional

* define has_nsfw_concept when non-pil output type

* lowercase linked to notebook name

---------

Co-authored-by: William Berman <WLBberman@gmail.com>

* Allow return pt x4 (#3236)

* Add all files

* update

* Allow fp16 attn for x4 upscaler (#3239)

* Add all files

* update

* Make sure vae is memory efficient for PT 1

* make style

* fix fast test (#3241)

* Adds a document on token merging (#3208)

* add document on token merging.

* fix headline.

* fix: headline.

* add some samples for comparison.

* [AudioLDM] Update docs to use updated ckpt (#3240)

* [AudioLDM] Update docs to use updated ckpt

* make style

* Release: v0.16.0

* Post release for 0.16.0 (#3244)

* Post release

* fix more

* [docs] only mention one stage (#3246)

* [docs] only mention one stage

* add blurb on auto accepting

---------

Co-authored-by: William Berman <WLBberman@gmail.com>

* Write model card in controlnet training script (#3229)

Write model card in controlnet training script.

* [2064]: Add stochastic sampler (sample_dpmpp_sde) (#3020)

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* Review comments

* [Review comment]: Add is_torchsde_available()

* [Review comment]: Test and docs

* [Review comment]

* [Review comment]

* [Review comment]

* [Review comment]

* [Review comment]

---------

Co-authored-by: njindal <njindal@adobe.com>

* [Stochastic Sampler][Slow Test]: Cuda test fixes (#3257)

[Slow Test]: Cuda test fixes

Co-authored-by: njindal <njindal@adobe.com>

* Remove required from tracker_project_name (#3260)

Remove required from tracker_project_name.

As observed by https://github.com/off99555 in https://github.com/huggingface/diffusers/issues/2695#issuecomment-1470755050, it already has a default value.

* adding required parameters while calling the get_up_block and get_down_block  (#3210)

* removed unnecessary parameters from get_up_block and get_down_block functions

* adding resnet_skip_time_act, resnet_out_scale_factor and cross_attention_norm to get_up_block and get_down_block functions

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* [docs] Update interface in repaint.mdx (#3119)

Update repaint.mdx

accomodate to #1701

* Update IF name to XL (#3262)

Co-authored-by: multimodalart <joaopaulo.passos+multimodal@gmail.com>

* fix typo in score sde pipeline (#3132)

* Fix typo in textual inversion JAX training script (#3123)

The pipeline is built as `pipe` but then used as `pipeline`.

* AudioDiffusionPipeline - fix encode method after config changes (#3114)

* config fixes

* deprecate get_input_dims

* Revert "Revert "[Community Pipelines] Update lpw_stable_diffusion pipeline"" (#3265)

Revert "Revert "[Community Pipelines] Update lpw_stable_diffusion pipeline" (#3201)"

This reverts commit 91a2a80eb2f98a9f64b9e287715add244dc6f2f3.

* Fix community pipelines (#3266)

* update notebook (#3259)

Co-authored-by: yiyixuxu <yixu@yis-macbook-pro.lan>

* [docs] add notes for stateful model changes (#3252)

* [docs] add notes for stateful model changes

* Update docs/source/en/optimization/fp16.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* link to accelerate docs for discarding hooks

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* [LoRA] quality of life improvements in the loading semantics and docs (#3180)

* 👽 qol improvements for LoRA.

* better function name?

* fix: LoRA weight loading with the new format.

* address Patrick's comments.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* change wording around encouraging the use of load_lora_weights().

* fix: function name.

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [Community Pipelines] EDICT pipeline implementation (#3153)

* EDICT pipeline initial commit

- Starting point taking from https://github.com/Joqsan/edict-diffusion

* refactor __init__() method

* minor refactoring

* refactor scheduler code

- remove scheduler and move its methods to the EDICTPipeline class

* make CFG optional
- refactor encode_prompt().
- include optional generator for sampling with vae.
- minor variable renaming

* add EDICT pipeline description to README.md

* replace preprocess() with VaeImageProcessor

* run make style and make quality commands

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [Docs]zh translated docs update (#3245)

* zh translated docs update

* update _toctree

* Update logging.mdx (#2863)

Fix typos

* Add multiple conditions to StableDiffusionControlNetInpaintPipeline (#3125)

* try multi controlnet inpaint

* multi controlnet inpaint

* multi controlnet inpaint

* Let's make sure that dreambooth always uploads to the Hub (#3272)

* Update Dreambooth README

* Adapt all docs as well

* automatically write model card

* fix

* make style

* Diffedit Zero-Shot Inpainting Pipeline (#2837)

* Update Pix2PixZero Auto-correlation Loss

* Add Stable Diffusion DiffEdit pipeline

* Add draft documentation and import code

* Bugfixes and refactoring

* Add option to not decode latents in the inversion process

* Harmonize preprocessing

* Revert "Update Pix2PixZero Auto-correlation Loss"

This reverts commit b218062fed08d6cc164206d6cb852b2b7b00847a.

* Update annotations

* rename `compute_mask` to `generate_mask`

* Update documentation

* Update docs

* Update Docs

* Fix copy

* Change shape of output latents to batch first

* Update docs

* Add first draft for tests

* Bugfix and update tests

* Add `cross_attention_kwargs` support for all pipeline methods

* Fix Copies

* Add support for PIL image latents

Add support for mask broadcasting

Update docs and tests

Align `mask` argument to `mask_image`

Remove height and width arguments

* Enable MPS Tests

* Move example docstrings

* Fix test

* Fix test

* fix pipeline inheritance

* Harmonize `prepare_image_latents` with StableDiffusionPix2PixZeroPipeline

* Register modules set to `None` in config for `test_save_load_optional_components`

* Move fixed logic to specific test class

* Clean changes to other pipelines

* Update new tests to coordinate with #2953

* Update slow tests for better results

* Safety to avoid potential problems with torch.inference_mode

* Add reference in SD Pipeline Overview

* Fix tests again

* Enforce determinism in noise for generate_mask

* Fix copies

* Widen test tolerance for fp16 based on `test_stable_diffusion_upscale_pipeline_fp16`

* Add LoraLoaderMixin and update `prepare_image_latents`

* clean up repeat and reg

* bugfix

* Remove invalid args from docs

Suppress spurious warning by repeating image before latent to mask gen

* add constant learning rate with custom rule (#3133)

* add constant lr with rules

* add constant with rules in TYPE_TO_SCHEDULER_FUNCTION

* add constant lr rate with rule

* hotfix code quality

* fix doc style

* change name constant_with_rules to piecewise constant

* Allow disabling torch 2_0 attention (#3273)

* Allow disabling torch 2_0 attention

* make style

* Update src/diffusers/models/attention.py

* [doc] add link to training script (#3271)

add link to training script

Co-authored-by: yiyixuxu <yixu@yis-macbook-pro.lan>

* temp disable spectogram diffusion tests (#3278)

The note-seq package throws an error on import because the default installed version of Ipython
is not compatible with python 3.8 which we run in the CI.
https://github.com/huggingface/diffusers/actions/runs/4830121056/jobs/8605954838#step:7:9

* Changed sample[0] to images[0] (#3304)

A pipeline object stores the results in `images` not in `sample`.
Current code blocks don't work.

* Typo in tutorial (#3295)

* Torch compile graph fix (#3286)

* fix more

* Fix more

* fix more

* Apply suggestions from code review

* fix

* make style

* make fix-copies

* fix

* make sure torch compile

* Clean

* fix test

* Postprocessing refactor img2img (#3268)

* refactor img2img VaeImageProcessor.postprocess

* remove copy from for init, run_safety_checker, decode_latents

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: yiyixuxu <yixu@yis-macbook-pro.lan>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* [Torch 2.0 compile] Fix more torch compile breaks (#3313)

* Fix more torch compile breaks

* add tests

* Fix all

* fix controlnet

* fix more

* Add Horace He as co-author.
>
>
Co-authored-by: Horace He <horacehe2007@yahoo.com>

* Add Horace He as co-author.

Co-authored-by: Horace He <horacehe2007@yahoo.com>

---------

Co-authored-by: Horace He <horacehe2007@yahoo.com>

* fix: scale_lr and sync example readme and docs. (#3299)

* fix: scale_lr and sync example readme and docs.

* fix doc link.

* Update stable_diffusion.mdx (#3310)

fixed import statement

* Fix missing variable assign in DeepFloyd-IF-II (#3315)

Fix missing variable assign

lol

* Correct doc build for patch releases (#3316)

Update build_documentation.yml

* Add Stable Diffusion RePaint to community pipelines (#3320)

* Add Stable Diffsuion RePaint to community pipelines

- Adds Stable Diffsuion RePaint to community pipelines
- Add Readme enty for pipeline

* Fix: Remove wrong import

- Remove wrong import
- Minor change in comments

* Fix: Code formatting of stable_diffusion_repaint

* Fix: ruff errors in stable_diffusion_repaint

* Fix multistep dpmsolver for cosine schedule (suitable for deepfloyd-if) (#3314)

* fix multistep dpmsolver for cosine schedule (deepfloy-if)

* fix a typo

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update all dpmsolver (singlestep, multistep, dpm, dpm++) for cosine noise schedule

* add test, fix style

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [docs] Improve LoRA docs (#3311)

* update docs

* add to toctree

* apply feedback

* Added input pretubation (#3292)

* Added input pretubation

* Fixed spelling

* Update write_own_pipeline.mdx (#3323)

* update controlling generation doc with latest goodies. (#3321)

* [Quality] Make style (#3341)

* Fix config dpm (#3343)

* Add the SDE variant of DPM-Solver and DPM-Solver++ (#3344)

* add SDE variant of DPM-Solver and DPM-Solver++

* add test

* fix typo

* fix typo

* Add upsample_size to AttnUpBlock2D, AttnDownBlock2D (#3275)

The argument `upsample_size` needs to be added to these modules to allow compatibility with other blocks that require this argument.

* Add UniDiffuser classes to __init__ files, modify transformer block to support pre- and post-LN, add fast default tests, fix some bugs.

* Update fast tests to use test checkpoints stored on the hub and to better match the reference UniDiffuser implementation.

* Fix code with make style.

* Revert "Fix code style with make style."

This reverts commit 10a174a12c82e6abd3d5a57665719a03dbb85ca7.

* Add self.image_encoder, self.text_decoder to list of models to offload to CPU in the enable_sequential_cpu_offload(...)/enable_model_cpu_offload(...) methods to make test_cpu_offload_forward_pass pass.

* Fix code quality with make style.

* Support using a data type embedding for UniDiffuser-v1.

* Add fast test for checking UniDiffuser-v1 sampling.

* Make changes so that the repository consistency tests pass.

* Add UniDiffuser dummy objects via make fix-copies.

* Fix bugs and make improvements to the UniDiffuser pipeline:
	- Improve batch size inference and fix bugs when num_images_per_prompt or num_prompts_per_image > 1
	- Add tests for num_images_per_prompt, num_prompts_per_image > 1
	- Improve check_inputs, especially regarding checking supplied latents
	- Add reset_mode method so that mode inference can be re-enabled after mode is set manually
	- Fix some warnings related to accessing class members directly instead of through their config
	- Small amount of refactoring in pipeline_unidiffuser.py

* Fix code style with make style.

* Add/edit docstrings for added classes and public pipeline methods. Also do some light refactoring.

* Add documentation for UniDiffuser and fix some typos/formatting in docstrings.

* Fix code with make style.

* Refactor and improve the UniDiffuser convert_from_ckpt.py script.

* Move the UniDiffusers convert_from_ckpy.py script to diffusers/scripts/convert_unidiffuser_to_diffusers.py

* Fix code quality via make style.

* Improve UniDiffuser slow tests.

* make style

* Fix some typos in the UniDiffuser docs.

* Remove outdated logic based on transformers version in UniDiffuser pipeline __init__.py

* Remove dependency on einops by refactoring einops operations to pure torch operations.

* make style

* Add slow test on full checkpoint for joint mode and correct expected image slices/text prefixes.

* make style

* Fix mixed precision issue by wrapping the offending code with the torch.autocast context manager.

* Revert "Fix mixed precision issue by wrapping the offending code with the torch.autocast context manager."

This reverts commit 1a58958ab4f024dbc4c90a6404c2e66210db6d00.

* Add fast test for CUDA/fp16 model behavior (currently failing).

* Fix the mixed precision issue and add additional tests of the pipeline cuda/fp16 functionality.

* make style

* Use a CLIPVisionModelWithProjection instead of CLIPVisionModel for image_encoder to better match the original UniDiffuser implementation.

* Make style and remove some testing code.

* Fix shape errors for the 'joint' and 'img2text' modes.

* Fix tests and remove some testing code.

* Add option to use fixed latents for UniDiffuserPipelineSlowTests and fix issue in modeling_text_decoder.py.

* Improve UniDiffuser docs, particularly the usage examples, and improve slow tests with new expected outputs.

* make style

* Fix examples to load model in float16.

* In image-to-text mode, sample from the autoencoder moment distribution instead of always getting its mode.

* make style

* When encoding the image using the VAE, scale the image latents by the VAE's scaling factor.

* make style

* Clean up code and make slow tests pass.

* make fix-copies

* [docs] Fix docstring (#3334)

fix docstring

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* if dreambooth lora (#3360)

* update IF stage I pipelines

add fixed variance schedulers and lora loading

* added kv lora attn processor

* allow loading into alternative lora attn processor

* make vae optional

* throw away predicted variance

* allow loading into added kv lora layer

* allow load T5

* allow pre compute text embeddings

* set new variance type in schedulers

* fix copies

* refactor all prompt embedding code

class prompts are now included in pre-encoding code
max tokenizer length is now configurable
embedding attention mask is now configurable

* fix for when variance type is not defined on scheduler

* do not pre compute validation prompt if not present

* add example test for if lora dreambooth

* add check for train text encoder and pre compute text embeddings

* Postprocessing refactor all others (#3337)

* add text2img

* fix-copies

* add

* add all other pipelines

* add

* add

* add

* add

* add

* make style

* style + fix copies

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>

* [docs] Improve safetensors docstring (#3368)

* clarify safetensor docstring

* fix typo

* apply feedback

* add: a warning message when using xformers in a PT 2.0 env. (#3365)

* add: a warning message when using xformers in a PT 2.0 env.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* StableDiffusionInpaintingPipeline - resize image w.r.t height and width (#3322)

* StableDiffusionInpaintingPipeline now resizes input images and masks w.r.t to passed input height and width. Default is already set to 512. This addresses the common tensor mismatch error. Also moved type check into relevant funciton to keep main pipeline body tidy.

* Fixed StableDiffusionInpaintingPrepareMaskAndMaskedImageTests

Due to previous commit these tests were failing as height and width need to be passed into the prepare_mask_and_masked_image function, I have updated the code and added a height/width variable per unit test as it seemed more appropriate than the current hard coded solution

* Added a resolution test to StableDiffusionInpaintPipelineSlowTests

this unit test simply gets the input and resizes it into some that would fail (e.g. would throw a tensor mismatch error/not a mult of 8). Then passes it through the pipeline and verifies it produces output with correct dims w.r.t the passed height and width

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make style

* [docs] Adapt a model (#3326)

* first draft

* apply feedback

* conv_in.weight thrown away

* [docs] Load safetensors (#3333)

* safetensors

* apply feedback

* apply feedback

* Apply suggestions from code review

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make style

* [Docs] Fix stable_diffusion.mdx typo (#3398)

Fix typo in last code block. Correct "prommpts" to "prompt"

* Support ControlNet v1.1 shuffle properly (#3340)

* add inferring_controlnet_cond_batch

* Revert "add inferring_controlnet_cond_batch"

This reverts commit abe8d6311d4b7f5b9409ca709c7fabf80d06c1a9.

* set guess_mode to True
whenever global_pool_conditions is True

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* nit

* add integration test

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [Tests] better determinism (#3374)

* enable deterministic pytorch and cuda operations.

* disable manual seeding.

* make style && make quality for unet_2d tests.

* enable determinism for the unet2dconditional model.

* add CUBLAS_WORKSPACE_CONFIG for better reproducibility.

* relax tolerance (very weird issue, though).

* revert to torch manual_seed() where needed.

* relax more tolerance.

* better placement of the cuda variable and relax more tolerance.

* enable determinism for 3d condition model.

* relax tolerance.

* add: determinism to alt_diffusion.

* relax tolerance for alt diffusion.

* dance diffusion.

* dance diffusion is flaky.

* test_dict_tuple_outputs_equivalent edit.

* fix two more tests.

* fix more ddim tests.

* fix: argument.

* change to diff in place of difference.

* fix: test_save_load call.

* test_save_load_float16 call.

* fix: expected_max_diff

* fix: paint by example.

* relax tolerance.

* add determinism to 1d unet model.

* torch 2.0 regressions seem to be brutal

* determinism to vae.

* add reason to skipping.

* up tolerance.

* determinism to vq.

* determinism to cuda.

* determinism to the generic test pipeline file.

* refactor general pipelines testing a bit.

* determinism to alt diffusion i2i

* up tolerance for alt diff i2i and audio diff

* up tolerance.

* determinism to audioldm

* increase tolerance for audioldm lms.

* increase tolerance for paint by paint.

* increase tolerance for repaint.

* determinism to cycle diffusion and sd 1.

* relax tol for cycle diffusion 🚲

* relax tol for sd 1.0

* relax tol for controlnet.

* determinism to img var.

* relax tol for img variation.

* tolerance to i2i sd

* make style

* determinism to inpaint.

* relax tolerance for inpaiting.

* determinism for inpainting legacy

* relax tolerance.

* determinism to instruct pix2pix

* determinism to model editing.

* model editing tolerance.

* panorama determinism

* determinism to pix2pix zero.

* determinism to sag.

* sd 2. determinism

* sd. tolerance

* disallow tf32 matmul.

* relax tolerance is all you need.

* make style and determinism to sd 2 depth

* relax tolerance for depth.

* tolerance to diffedit.

* tolerance to sd 2 inpaint.

* up tolerance.

* determinism in upscaling.

* tolerance in upscaler.

* more tolerance relaxation.

* determinism to v pred.

* up tol for v_pred

* unclip determinism

* determinism to unclip img2img

* determinism to text to video.

* determinism to last set of tests

* up tol.

* vq cumsum doesn't have a deterministic kernel

* relax tol

* relax tol

* [docs] Add transformers to install (#3388)

add transformers to install

* [deepspeed] partial ZeRO-3 support (#3076)

* [deepspeed] partial ZeRO-3 support

* cleanup

* improve deepspeed fixes

* Improve

* make style

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add omegaconf for tests (#3400)

Add omegaconfg

* Fix various bugs with LoRA Dreambooth and Dreambooth script (#3353)

* Improve checkpointing lora

* fix more

* Improve doc string

* Update src/diffusers/loaders.py

* make stytle

* Apply suggestions from code review

* Update src/diffusers/loaders.py

* Apply suggestions from code review

* Apply suggestions from code review

* better

* Fix all

* Fix multi-GPU dreambooth

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix all

* make style

* make style

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix docker file (#3402)

* up

* up

* fix: deepseepd_plugin retrieval from accelerate state (#3410)

* [Docs] Add `sigmoid` beta_scheduler to docstrings of relevant Schedulers (#3399)

* Add `sigmoid` beta scheduler to `DDPMScheduler` docstring

* Add `sigmoid` beta scheduler to `RePaintScheduler` docstring

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Don't install accelerate and transformers from source (#3415)

* Don't install transformers and accelerate from source (#3414)

* Improve fast tests (#3416)

Update pr_tests.yml

* attention refactor: the trilogy  (#3387)

* Replace `AttentionBlock` with `Attention`

* use _from_deprecated_attn_block check re: @patrickvonplaten

* [Docs] update the PT 2.0 optimization doc with latest findings (#3370)

* add: benchmarking stats for A100 and V100.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* address patrick's comments.

* add: rtx 4090 stats

* ⚔ benchmark reports done

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* 3313 pr link.

* add: plots.

Co-authored-by: Pedro <pedro@huggingface.co>

* fix formattimg

* update number percent.

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix style rendering (#3433)

* Fix style rendering.

* Fix typo

* unCLIP scheduler do not use note (#3417)

* Replace deprecated command with environment file (#3409)

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix warning message pipeline loading (#3446)

* add stable diffusion tensorrt img2img pipeline (#3419)

* add stable diffusion tensorrt img2img pipeline

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* update docstrings

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

---------

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* Refactor controlnet and add img2img and inpaint (#3386)

* refactor controlnet and add img2img and inpaint

* First draft to get pipelines to work

* make style

* Fix more

* Fix more

* More tests

* Fix more

* Make inpainting work

* make style and more tests

* Apply suggestions from code review

* up

* make style

* Fix imports

* Fix more

* Fix more

* Improve examples

* add test

* Make sure import is correctly deprecated

* Make sure everything works in compile mode

* make sure authorship is correctly attributed

* [Scheduler] DPM-Solver (++) Inverse Scheduler (#3335)

* Add DPM-Solver Multistep Inverse Scheduler

* Add draft tests for DiffEdit

* Add inverse sde-dpmsolver steps to tune image diversity from inverted latents

* Fix tests

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [Docs] Fix incomplete docstring for resnet.py (#3438)

Fix incomplete docstrings for resnet.py

* fix tiled vae blend extent range (#3384)

fix tiled vae bleand extent range

* Small update to "Next steps" section (#3443)

Small update to "Next steps" section:

- PyTorch 2 is recommended.
- Updated improvement figures.

* Allow arbitrary aspect ratio in IFSuperResolutionPipeline (#3298)

* Update pipeline_if_superresolution.py

Allow arbitrary aspect ratio in IFSuperResolutionPipeline by using the input image shape

* IFSuperResolutionPipeline: allow the user to override the height and width through the arguments

* update IFSuperResolutionPipeline width/height doc string to match StableDiffusionInpaintPipeline conventions

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Adding 'strength' parameter to StableDiffusionInpaintingPipeline  (#3424)

* Added explanation of 'strength' parameter

* Added get_timesteps function which relies on new strength parameter

* Added `strength` parameter which defaults to 1.

* Swapped ordering so `noise_timestep` can be calculated before masking the image

this is required when you aren't applying 100% noise to the masked region, e.g. strength < 1.

* Added strength to check_inputs, throws error if out of range

* Changed `prepare_latents` to initialise latents w.r.t strength

inspired from the stable diffusion img2img pipeline, init latents are initialised by converting the init image into a VAE latent and adding noise (based upon the strength parameter passed in), e.g. random when strength = 1, or the init image at strength = 0.

* WIP: Added a unit test for the new strength parameter in the StableDiffusionInpaintingPipeline

still need to add correct regression values

* Created a is_strength_max to initialise from pure random noise

* Updated unit tests w.r.t new strength parameter + fixed new strength unit test

* renamed parameter to avoid confusion with variable of same name

* Updated regression values for new strength test - now passes

* removed 'copied from' comment as this method is now different and divergent from the cpy

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Ensure backwards compatibility for prepare_mask_and_masked_image

created a return_image boolean and initialised to false

* Ensure backwards compatibility for prepare_latents

* Fixed copy check typo

* Fixes w.r.t backward compibility changes

* make style

* keep function argument ordering same for backwards compatibility in callees with copied from statements

* make fix-copies

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: William Berman <WLBberman@gmail.com>

* [WIP] Bugfix - Pipeline.from_pretrained is broken when the pipeline is partially downloaded (#3448)

Added bugfix using f strings.

* Fix gradient checkpointing bugs in freezing part of models (requires_grad=False) (#3404)

* gradient checkpointing bug fix

* bug fix; changes for reviews

* reformat

* reformat

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Make dreambooth lora more robust to orig unet (#3462)

* Make dreambooth lora more robust to orig unet

* up

* Reduce peak VRAM by releasing large attention tensors (as soon as they're unnecessary) (#3463)

Release large tensors in attention (as soon as they're no longer required). Reduces peak VRAM by nearly 2 GB for 1024x1024 (even after slicing), and the savings scale up with image size.

* Add min snr to text2img lora training script (#3459)

add min snr to text2img lora training script

* Add inpaint lora scale support (#3460)

* add inpaint lora scale support

* add inpaint lora scale test

---------

Co-authored-by: yueyang.hyy <yueyang.hyy@alibaba-inc.com>

* [From ckpt] Fix from_ckpt (#3466)

* Correct from_ckpt

* make style

* Update full dreambooth script to work with IF (#3425)

* Add IF dreambooth docs (#3470)

* parameterize pass single args through tuple (#3477)

* attend and excite tests disable determinism on the class level (#3478)

* dreambooth docs torch.compile note (#3471)

* dreambooth docs torch.compile note

* Update examples/dreambooth/README.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update examples/dreambooth/README.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* add: if entry in the dreambooth training docs. (#3472)

* [docs] Textual inversion inference (#3473)

* add textual inversion inference to docs

* add to toctree

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* [docs] Distributed inference (#3376)

* distributed inference

* move to inference section

* apply feedback

* update with split_between_processes

* apply feedback

* [{Up,Down}sample1d] explicit view kernel size as number elements in flattened indices (#3479)

explicit view kernel size as number elements in flattened indices

* mps & onnx tests rework (#3449)

* Remove ONNX tests from PR.

They are already a part of push_tests.yml.

* Remove mps tests from PRs.

They are already performed on push.

* Fix workflow name for fast push tests.

* Extract mps tests to a workflow.

For better control/filtering.

* Remove --extra-index-url from mps tests

* Increase tolerance of mps test

This test passes in my Mac (Ventura 13.3) but fails in the CI hardware
(Ventura 13.2). I ran the local tests following the same steps that
exist in the CI workflow.

* Temporarily run mps tests on pr

So we can test.

* Revert "Temporarily run mps tests on pr"

Tests passed, go back to running on push.

* [Attention processor] Better warning message when shifting to `AttnProcessor2_0` (#3457)

* add: debugging to enabling memory efficient processing

* add: better warning message.

* [Docs] add note on local directory path. (#3397)

add note on local directory path.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Refactor full determinism (#3485)

* up

* fix more

* Apply suggestions from code review

* fix more

* fix more

* Check it

* Remove 16:8

* fix more

* fix more

* fix more

* up

* up

* Test only stable diffusion

* Test only two files

* up

* Try out spinning up processes that can be killed

* up

* Apply suggestions from code review

* up

* up

* Fix DPM single (#3413)

* Fix DPM single

* add test

* fix one more bug

* Apply suggestions from code review

Co-authored-by: StAlKeR7779 <stalkek7779@yandex.ru>

---------

Co-authored-by: StAlKeR7779 <stalkek7779@yandex.ru>

* Add `use_Karras_sigmas` to DPMSolverSinglestepScheduler (#3476)

* add use_karras_sigmas

* add karras test

* add doc

* Adds local_files_only bool to prevent forced online connection (#3486)

* make style

* [Docs] Korean translation (optimization, training) (#3488)

* feat) optimization kr translation

* fix) typo, italic setting

* feat) dreambooth, text2image kr

* feat) lora kr

* fix) LoRA

* fix) fp16 fix

* fix) doc-builder style

* fix) fp16 일부 단어 수정

* fix) fp16 style fix

* fix) opt, training docs update

* feat) toctree update

* feat) toctree update

---------

Co-authored-by: Chanran Kim <seriousran@gmail.com>

* DataLoader respecting EXIF data in Training Images (#3465)

* DataLoader will now bake in any transforms or image manipulations contained in the EXIF

Images may have rotations stored in EXIF. Training using such images will cause those transforms to be ignored while training and thus produce unexpected results

* Fixed the Dataloading EXIF issue in main DreamBooth training as well

* Run make style (black & isort)

* make style

* feat: allow disk offload for diffuser models (#3285)

* allow disk offload for diffuser models

* sort import

* add max_memory argument

* Changed sample[0] to images[0] (#3304)

A pipeline object stores the results in `images` not in `sample`.
Current code blocks don't work.

* Typo in tutorial (#3295)

* Torch compile graph fix (#3286)

* fix more

* Fix more

* fix more

* Apply suggestions from code review

* fix

* make style

* make fix-copies

* fix

* make sure torch compile

* Clean

* fix test

* Postprocessing refactor img2img (#3268)

* refactor img2img VaeImageProcessor.postprocess

* remove copy from for init, run_safety_checker, decode_latents

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: yiyixuxu <yixu@yis-macbook-pro.lan>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* [Torch 2.0 compile] Fix more torch compile breaks (#3313)

* Fix more torch compile breaks

* add tests

* Fix all

* fix controlnet

* fix more

* Add Horace He as co-author.
>
>
Co-authored-by: Horace He <horacehe2007@yahoo.com>

* Add Horace He as co-author.

Co-authored-by: Horace He <horacehe2007@yahoo.com>

---------

Co-authored-by: Horace He <horacehe2007@yahoo.com>

* fix: scale_lr and sync example readme and docs. (#3299)

* fix: scale_lr and sync example readme and docs.

* fix doc link.

* Update stable_diffusion.mdx (#3310)

fixed import statement

* Fix missing variable assign in DeepFloyd-IF-II (#3315)

Fix missing variable assign

lol

* Correct doc build for patch releases (#3316)

Update build_documentation.yml

* Add Stable Diffusion RePaint to community pipelines (#3320)

* Add Stable Diffsuion RePaint to community pipelines

- Adds Stable Diffsuion RePaint to community pipelines
- Add Readme enty for pipeline

* Fix: Remove wrong import

- Remove wrong import
- Minor change in comments

* Fix: Code formatting of stable_diffusion_repaint

* Fix: ruff errors in stable_diffusion_repaint

* Fix multistep dpmsolver for cosine schedule (suitable for deepfloyd-if) (#3314)

* fix multistep dpmsolver for cosine schedule (deepfloy-if)

* fix a typo

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update all dpmsolver (singlestep, multistep, dpm, dpm++) for cosine noise schedule

* add test, fix style

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [docs] Improve LoRA docs (#3311)

* update docs

* add to toctree

* apply feedback

* Added input pretubation (#3292)

* Added input pretubation

* Fixed spelling

* Update write_own_pipeline.mdx (#3323)

* update controlling generation doc with latest goodies. (#3321)

* [Quality] Make style (#3341)

* Fix config dpm (#3343)

* Add the SDE variant of DPM-Solver and DPM-Solver++ (#3344)

* add SDE variant of DPM-Solver and DPM-Solver++

* add test

* fix typo

* fix typo

* Add upsample_size to AttnUpBlock2D, AttnDownBlock2D (#3275)

The argument `upsample_size` needs to be added to these modules to allow compatibility with other blocks that require this argument.

* Rename --only_save_embeds to --save_as_full_pipeline (#3206)

* Set --only_save_embeds to False by default

Due to how the option is named, it makes more sense to behave like this.

* Refactor only_save_embeds to save_as_full_pipeline

* [AudioLDM] Generalise conversion script (#3328)

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Fix TypeError when using prompt_embeds and negative_prompt (#2982)

* test: Added test case

* fix: fixed type checking issue on _encode_prompt

* fix: fixed copies consistency

* fix: one copy was not sufficient

* Fix pipeline class on README (#3345)

Update README.md

* Inpainting: typo in docs (#3331)

Typo in docs

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add `use_Karras_sigmas` to LMSDiscreteScheduler (#3351)

* add karras sigma to lms discrete scheduler

* add test for lms_scheduler karras

* reformat test lms

* Batched load of textual inversions (#3277)

* Batched load of textual inversions

- Only call resize_token_embeddings once per batch as it is the most expensive operation
- Allow pretrained_model_name_or_path and token to be an optional list
- Remove Dict from type annotation pretrained_model_name_or_path as it was not supported in this function
- Add comment that single files (e.g. .pt/.safetensors) are supported
- Add comment for token parameter
- Convert token override log message from warning to info

* Update src/diffusers/loaders.py

Check for duplicate tokens

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update condition for None tokens

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make fix-copies

* [docs] Fix docstring (#3334)

fix docstring

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* if dreambooth lora (#3360)

* update IF stage I pipelines

add fixed variance schedulers and lora loading

* added kv lora attn processor

* allow loading into alternative lora attn processor

* make vae optional

* throw away predicted variance

* allow loading into added kv lora layer

* allow load T5

* allow pre compute text embeddings

* set new variance type in schedulers

* fix copies

* refactor all prompt embedding code

class prompts are now included in pre-encoding code
max tokenizer length is now configurable
embedding attention mask is now configurable

* fix for when variance type is not defined on scheduler

* do not pre compute validation prompt if not present

* add example test for if lora dreambooth

* add check for train text encoder and pre compute text embeddings

* Postprocessing refactor all others (#3337)

* add text2img

* fix-copies

* add

* add all other pipelines

* add

* add

* add

* add

* add

* make style

* style + fix copies

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>

* [docs] Improve safetensors docstring (#3368)

* clarify safetensor docstring

* fix typo

* apply feedback

* add: a warning message when using xformers in a PT 2.0 env. (#3365)

* add: a warning message when using xformers in a PT 2.0 env.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* StableDiffusionInpaintingPipeline - resize image w.r.t height and width (#3322)

* StableDiffusionInpaintingPipeline now resizes input images and masks w.r.t to passed input height and width. Default is already set to 512. This addresses the common tensor mismatch error. Also moved type check into relevant funciton to keep main pipeline body tidy.

* Fixed StableDiffusionInpaintingPrepareMaskAndMaskedImageTests

Due to previous commit these tests were failing as height and width need to be passed into the prepare_mask_and_masked_image function, I have updated the code and added a height/width variable per unit test as it seemed more appropriate than the current hard coded solution

* Added a resolution test to StableDiffusionInpaintPipelineSlowTests

this unit test simply gets the input and resizes it into some that would fail (e.g. would throw a tensor mismatch error/not a mult of 8). Then passes it through the pipeline and verifies it produces output with correct dims w.r.t the passed height and width

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make style

* [docs] Adapt a model (#3326)

* first draft

* apply feedback

* conv_in.weight thrown away

* [docs] Load safetensors (#3333)

* safetensors

* apply feedback

* apply feedback

* Apply suggestions from code review

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make style

* [Docs] Fix stable_diffusion.mdx typo (#3398)

Fix typo in last code block. Correct "prommpts" to "prompt"

* Support ControlNet v1.1 shuffle properly (#3340)

* add inferring_controlnet_cond_batch

* Revert "add inferring_controlnet_cond_batch"

This reverts commit abe8d6311d4b7f5b9409ca709c7fabf80d06c1a9.

* set guess_mode to True
whenever global_pool_conditions is True

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* nit

* add integration test

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [Tests] better determinism (#3374)

* enable deterministic pytorch and cuda operations.

* disable manual seeding.

* make style && make quality for unet_2d tests.

* enable determinism for the unet2dconditional model.

* add CUBLAS_WORKSPACE_CONFIG for better reproducibility.

* relax tolerance (very weird issue, though).

* revert to torch manual_seed() where needed.

* relax more tolerance.

* better placement of the cuda variable and relax more tolerance.

* enable determinism for 3d condition model.

* relax tolerance.

* add: determinism to alt_diffusion.

* relax tolerance for alt diffusion.

* dance diffusion.

* dance diffusion is flaky.

* test_dict_tuple_outputs_equivalent edit.

* fix two more tests.

* fix more ddim tests.

* fix: argument.

* change to diff in place of difference.

* fix: test_save_load call.

* test_save_load_float16 call.

* fix: expected_max_diff

* fix: paint by example.

* relax tolerance.

* add determinism to 1d unet model.

* torch 2.0 regressions seem to be brutal

* determinism to vae.

* add reason to skipping.

* up tolerance.

* determinism to vq.

* determinism to cuda.

* determinism to the generic test pipeline file.

* refactor general pipelines testing a bit.

* determinism to alt diffusion i2i

* up tolerance for alt diff i2i and audio diff

* up tolerance.

* determinism to audioldm

* increase tolerance for audioldm lms.

* increase tolerance for paint by paint.

* increase tolerance for repaint.

* determinism to cycle diffusion and sd 1.

* relax tol for cycle diffusion 🚲

* relax tol for sd 1.0

* relax tol for controlnet.

* determinism to img var.

* relax tol for img variation.

* tolerance to i2i sd

* make style

* determinism to inpaint.

* relax tolerance for inpaiting.

* determinism for inpainting legacy

* relax tolerance.

* determinism to instruct pix2pix

* determinism to model editing.

* model editing tolerance.

* panorama determinism

* determinism to pix2pix zero.

* determinism to sag.

* sd 2. determinism

* sd. tolerance

* disallow tf32 matmul.

* relax tolerance is all you need.

* make style and determinism to sd 2 depth

* relax tolerance for depth.

* tolerance to diffedit.

* tolerance to sd 2 inpaint.

* up tolerance.

* determinism in upscaling.

* tolerance in upscaler.

* more tolerance relaxation.

* determinism to v pred.

* up tol for v_pred

* unclip determinism

* determinism to unclip img2img

* determinism to text to video.

* determinism to last set of tests

* up tol.

* vq cumsum doesn't have a deterministic kernel

* relax tol

* relax tol

* [docs] Add transformers to install (#3388)

add transformers to install

* [deepspeed] partial ZeRO-3 support (#3076)

* [deepspeed] partial ZeRO-3 support

* cleanup

* improve deepspeed fixes

* Improve

* make style

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add omegaconf for tests (#3400)

Add omegaconfg

* Fix various bugs with LoRA Dreambooth and Dreambooth script (#3353)

* Improve checkpointing lora

* fix more

* Improve doc string

* Update src/diffusers/loaders.py

* make stytle

* Apply suggestions from code review

* Update src/diffusers/loaders.py

* Apply suggestions from code review

* Apply suggestions from code review

* better

* Fix all

* Fix multi-GPU dreambooth

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix all

* make style

* make style

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix docker file (#3402)

* up

* up

* fix: deepseepd_plugin retrieval from accelerate state (#3410)

* [Docs] Add `sigmoid` beta_scheduler to docstrings of relevant Schedulers (#3399)

* Add `sigmoid` beta scheduler to `DDPMScheduler` docstring

* Add `sigmoid` beta scheduler to `RePaintScheduler` docstring

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Don't install accelerate and transformers from source (#3415)

* Don't install transformers and accelerate from source (#3414)

* Improve fast tests (#3416)

Update pr_tests.yml

* attention refactor: the trilogy  (#3387)

* Replace `AttentionBlock` with `Attention`

* use _from_deprecated_attn_block check re: @patrickvonplaten

* [Docs] update the PT 2.0 optimization doc with latest findings (#3370)

* add: benchmarking stats for A100 and V100.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* address patrick's comments.

* add: rtx 4090 stats

* ⚔ benchmark reports done

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* 3313 pr link.

* add: plots.

Co-authored-by: Pedro <pedro@huggingface.co>

* fix formattimg

* update number percent.

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix style rendering (#3433)

* Fix style rendering.

* Fix typo

* unCLIP scheduler do not use note (#3417)

* Replace deprecated command with environment file (#3409)

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* …
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* Fix a bug of pano when not doing CFG (#3030)

* Fix a bug of pano when not doing CFG

* enhance code quality

* apply formatting.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Text2video zero refinements (#3070)

* fix progress bar issue in pipeline_text_to_video_zero.py. Copy scheduler after first backward

* fix tensor loading in test_text_to_video_zero.py

* make style && make quality

* Release: v0.15.0

* [Tests] Speed up panorama tests (#3067)

* fix: norm group test for UNet3D.

* chore: speed up the panorama tests (fast).

* set default value of _test_inference_batch_single_identical.

* fix: batch_sizes default value.

* [Post release] v0.16.0dev (#3072)

* Adds profiling flags, computes train metrics average. (#3053)

* WIP controlnet training

- bugfix --streaming
- bugfix running report_to!='wandb'
- adds memory profile before validation

* Adds final logging statement.

* Sets train epochs to 11.

Looking at a longer ~16ep run, we see only good validation images
after ~11ep:

https://wandb.ai/andsteing/controlnet_fill50k/runs/3j2hx6n8

* Removes --logging_dir (it's not used).

* Adds --profile flags.

* Updates --output_dir=runs/fill-circle-{timestamp}.

* Compute mean of `train_metrics`.

Previously `train_metrics[-1]` was logged, resulting in very bumpy train
metrics.

* Improves logging a bit.

- adds l2_grads gradient norm logging
- adds steps_per_sec
- sets walltime as x coordinate of train/step
- logs controlnet_params config

* Adds --ccache (doesn't really help though).

* minor fix in controlnet flax example (#2986)

* fix the error when push_to_hub but not log validation

* contronet_from_pt & controlnet_revision

* add intermediate checkpointing to the guide

* Bugfix --profile_steps

* Sets `RACKER_PROJECT_NAME='controlnet_fill50k'`.

* Logs fractional epoch.

* Adds relative `walltime` metric.

* Adds `StepTraceAnnotation` and uses `global_step` insetad of `step`.

* Applied `black`.

* Streamlines commands in README a bit.

* Removes `--ccache`.

This makes only a very small difference (~1 min) with this model size, so removing
the option introduced in cdb3cc.

* Re-ran `black`.

* Update examples/controlnet/README.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Converts spaces to tab.

* Removes repeated args.

* Skips first step (compilation) in profiling

* Updates README with profiling instructions.

* Unifies tabs/spaces in README.

* Re-ran style & quality.

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* [Pipelines] Make sure that None functions are correctly not saved (#3080)

* doc string example remove from_pt (#3083)

* [Tests] parallelize (#3078)

* [Tests] parallelize

* finish folder structuring

* Parallelize tests more

* Correct saving of pipelines

* make sure logging level is correct

* try again

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Throw deprecation warning for return_cached_folder (#3092)

Throw deprecation warning

* Allow SD attend and excite pipeline to work with any size output images (#2835)

Allow stable diffusion attend and excite pipeline to work with any size output image. Re: #2476, #2603

* [docs] Update community pipeline docs (#2989)

* update community pipeline docs

* fix formatting

* explain sharing workflows

* Add to support Guess Mode for StableDiffusionControlnetPipleline (#2998)

* add guess mode (WIP)

* fix uncond/cond order

* support guidance_scale=1.0 and batch != 1

* remove magic coeff

* add docstring

* add intergration test

* add document to controlnet.mdx

* made the comments a bit more explanatory

* fix table

* fix default value for attend-and-excite (#3099)

* fix default

* remvoe one line as requested by gc team  (#3077)

remvoe one line

* ddpm custom timesteps (#3007)

add custom timesteps test

add custom timesteps descending order check

docs

timesteps -> custom_timesteps

can only pass one of num_inference_steps and timesteps

* Fix breaking change in `pipeline_stable_diffusion_controlnet.py` (#3118)

fix breaking change

* Add global pooling to controlnet (#3121)

* [Bug fix] Fix img2img processor with safety checker (#3127)

Fix img2img processor with safety checker

* [Bug fix] Make sure correct timesteps are chosen for img2img (#3128)

Make sure correct timesteps are chosen for img2img

* Improve deprecation warnings (#3131)

* Fix config deprecation (#3129)

* Better deprecation message

* Better deprecation message

* Better doc string

* Fixes

* fix more

* fix more

* Improve __getattr__

* correct more

* fix more

* fix

* Improve more

* more improvements

* fix more

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* make style

* Fix all rest & add tests & remove old deprecation fns

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* feat: verfication of multi-gpu support for select examples. (#3126)

* feat: verfication of multi-gpu support for select examples.

* add: multi-gpu training sections to the relvant doc pages.

* speed up attend-and-excite fast tests (#3079)

* Optimize log_validation in train_controlnet_flax (#3110)

extract pipeline from log_validation

* make style

* Correct textual inversion readme (#3145)

* Update README.md

* Apply suggestions from code review

* Add unet act fn to other model components (#3136)

Adding act fn config to the unet timestep class embedding and conv
activation.

The custom activation defaults to silu which is the default
activation function for both the conv act and the timestep class
embeddings so default behavior is not changed.

The only unet which use the custom activation is the stable diffusion
latent upscaler https://huggingface.co/stabilityai/sd-x2-latent-upscaler/blob/main/unet/config.json
(I ran a script against the hub to confirm).
The latent upscaler does not use the conv activation nor the timestep
class embeddings so we don't change its behavior.

* class labels timestep embeddings projection dtype cast (#3137)

This mimics the dtype cast for the standard time embeddings

* [ckpt loader] Allow loading the Inpaint and Img2Img pipelines, while loading a ckpt model (#2705)

* [ckpt loader] Allow loading the Inpaint and Img2Img pipelines, while loading a ckpt model

* Address review comment from PR

* PyLint formatting

* Some more pylint fixes, unrelated to our change

* Another pylint fix

* Styling fix

* add from_ckpt method as Mixin (#2318)

* add mixin class for pipeline from original sd ckpt

* Improve

* make style

* merge main into

* Improve more

* fix more

* up

* Apply suggestions from code review

* finish docs

* rename

* make style

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add TensorRT SD/txt2img Community Pipeline to diffusers along with TensorRT utils (#2974)

* Add SD/txt2img Community Pipeline to diffusers along with TensorRT utils

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* update installation command

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* update tensorrt installation

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* changes
1. Update setting of cache directory
2. Address comments: merge utils and pipeline code.
3. Address comments: Add section in README

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* apply make style

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

---------

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Correct `Transformer2DModel.forward` docstring (#3074)

⚙️chore(transformer_2d) update function signature for encoder_hidden_states

* Update pipeline_stable_diffusion_inpaint_legacy.py (#2903)

* Update pipeline_stable_diffusion_inpaint_legacy.py

* fix preprocessing of Pil images with adequate batch size

* revert map

* add tests

* reformat

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* next try to fix the style

* wth is this

* Update testing_utils.py

* Update testing_utils.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

* Update test_stable_diffusion_inpaint_legacy.py

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Modified altdiffusion pipline to support altdiffusion-m18 (#2993)

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

* Modified altdiffusion pipline to support altdiffusion-m18

---------

Co-authored-by: root <fulong_ye@163.com>

* controlnet training resize inputs to multiple of 8 (#3135)

controlnet training center crop input images to multiple of 8

The pipeline code resizes inputs to multiples of 8.
Not doing this resizing in the training script is causing
the encoded image to have different height/width dimensions
than the encoded conditioning image (which uses a separate
encoder that's part of the controlnet model).

We resize and center crop the inputs to make sure they're the
same size (as well as all other images in the batch). We also
check that the initial resolution is a multiple of 8.

* adding custom diffusion training to diffusers examples (#3031)

* diffusers==0.14.0 update

* custom diffusion update

* custom diffusion update

* custom diffusion update

* custom diffusion update

* custom diffusion update

* custom diffusion update

* custom diffusion

* custom diffusion

* custom diffusion

* custom diffusion

* custom diffusion

* apply formatting and get rid of bare except.

* refactor readme and other minor changes.

* misc refactor.

* fix: repo_id issue and loaders logging bug.

* fix: save_model_card.

* fix: save_model_card.

* fix: save_model_card.

* add: doc entry.

* refactor doc,.

* custom diffusion

* custom diffusion

* custom diffusion

* apply style.

* remove tralining whitespace.

* fix: toctree entry.

* remove unnecessary print.

* custom diffusion

* custom diffusion

* custom diffusion test

* custom diffusion xformer update

* custom diffusion xformer update

* custom diffusion xformer update

---------

Co-authored-by: Nupur Kumari <nupurkumari@Nupurs-MacBook-Pro.local>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Nupur Kumari <nupurkumari@nupurs-mbp.wifi.local.cmu.edu>

* make style

* Update custom_diffusion.mdx (#3165)

Add missing newlines for rendering the links correctly

* Added distillation for quantization example on textual inversion. (#2760)

* Added distillation for quantization example on textual inversion.

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

* refined readme and code style.

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

* Update text2images.py

* refined code of model load and added compatibility check.

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

* fixed code style.

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

* fix C403 [*] Unnecessary `list` comprehension (rewrite as a `set` comprehension)

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

---------

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

* Update Noise Autocorrelation Loss Function for Pix2PixZero Pipeline (#2942)

* Update Pix2PixZero Auto-correlation Loss

* Add fast inversion tests

* Clarify purpose and mark as deprecated

Fix inversion prompt broadcasting

* Register modules set to `None` in config for `test_save_load_optional_components`

* Update new tests to coordinate with #2953

* [DreamBooth] add text encoder LoRA support in the DreamBooth training script (#3130)

* add: LoRA text encoder support for DreamBooth example.

* fix initialization.

* fix: modification call.

* add: entry in the readme.

* use dog dataset from hub.

* fix: params to clip.

* add entry to the LoRA doc.

* add: tests for lora.

* remove unnecessary list comprehension./

* Update Habana Gaudi documentation (#3169)

* Update Habana Gaudi doc

* Fix tables

* Add model offload to x4 upscaler (#3187)

* Add model offload to x4 upscaler

* fix

* [docs] Deterministic algorithms (#3172)

deterministic algos

* Update custom_diffusion.mdx to credit the author (#3163)

* Update custom_diffusion.mdx

* fix: unnecessary list comprehension.

* Fix TensorRT community pipeline device set function (#3157)

pass silence_dtype_warnings as kwarg

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make `from_flax` work for controlnet (#3161)

fix from_flax

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [docs] Clarify training args (#3146)

* clarify training arg

* apply feedback

* Multi Vector Textual Inversion (#3144)

* Multi Vector

* Improve

* fix multi token

* improve test

* make style

* Update examples/test_examples.py

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* update

* Finish

* Apply suggestions from code review

---------

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Add `Karras sigmas` to HeunDiscreteScheduler (#3160)

* Add karras pattern to discrete heun scheduler

* Add integration test

* Fix failing CI on pytorch test on M1 (mps)

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [AudioLDM] Fix dtype of returned waveform (#3189)

* Fix bug in train_dreambooth_lora (#3183)

* Update train_dreambooth_lora.py

fix bug

* Update train_dreambooth_lora.py

* [Community Pipelines] Update lpw_stable_diffusion pipeline (#3197)

* Update lpw_stable_diffusion.py

* fix cpu offload

* Make sure VAE attention works with Torch 2_0 (#3200)

* Make sure attention works with Torch 2_0

* make style

* Fix more

* Revert "[Community Pipelines] Update lpw_stable_diffusion pipeline" (#3201)

Revert "[Community Pipelines] Update lpw_stable_diffusion pipeline (#3197)"

This reverts commit 9965cb50eac12e397473f01535aab43aae76b4ab.

* [Bug fix] Fix batch size attention head size mismatch (#3214)

* fix mixed precision training on train_dreambooth_inpaint_lora (#3138)

cast to weight dtype

* adding enable_vae_tiling and disable_vae_tiling functions (#3225)

adding enable_vae_tiling and disable_val_tiling functions

* Add ControlNet v1.1 docs (#3226)

Add v1.1 docs

* Fix issue in maybe_convert_prompt (#3188)

When the token used for textual inversion does not have any special symbols (e.g. it is not surrounded by <>), the tokenizer does not properly split the replacement tokens.  Adding a space for the padding tokens fixes this.

* Sync cache version check from transformers (#3179)

sync cache version check from transformers

* Fix docs text inversion (#3166)

* Fix docs text inversion

* Apply suggestions from code review

* add model (#3230)

* add

* clean

* up

* clean up more

* fix more tests

* Improve docs further

* improve

* more fixes docs

* Improve docs more

* Update src/diffusers/models/unet_2d_condition.py

* fix

* up

* update doc links

* make fix-copies

* add safety checker and watermarker to stage 3 doc page code snippets

* speed optimizations docs

* memory optimization docs

* make style

* add watermarking snippets to doc string examples

* make style

* use pt_to_pil helper functions in doc strings

* skip mps tests

* Improve safety

* make style

* new logic

* fix

* fix bad onnx design

* make new stable diffusion upscale pipeline model arguments optional

* define has_nsfw_concept when non-pil output type

* lowercase linked to notebook name

---------

Co-authored-by: William Berman <WLBberman@gmail.com>

* Allow return pt x4 (#3236)

* Add all files

* update

* Allow fp16 attn for x4 upscaler (#3239)

* Add all files

* update

* Make sure vae is memory efficient for PT 1

* make style

* fix fast test (#3241)

* Adds a document on token merging (#3208)

* add document on token merging.

* fix headline.

* fix: headline.

* add some samples for comparison.

* [AudioLDM] Update docs to use updated ckpt (#3240)

* [AudioLDM] Update docs to use updated ckpt

* make style

* Release: v0.16.0

* Post release for 0.16.0 (#3244)

* Post release

* fix more

* [docs] only mention one stage (#3246)

* [docs] only mention one stage

* add blurb on auto accepting

---------

Co-authored-by: William Berman <WLBberman@gmail.com>

* Write model card in controlnet training script (#3229)

Write model card in controlnet training script.

* [2064]: Add stochastic sampler (sample_dpmpp_sde) (#3020)

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* Review comments

* [Review comment]: Add is_torchsde_available()

* [Review comment]: Test and docs

* [Review comment]

* [Review comment]

* [Review comment]

* [Review comment]

* [Review comment]

---------

Co-authored-by: njindal <njindal@adobe.com>

* [Stochastic Sampler][Slow Test]: Cuda test fixes (#3257)

[Slow Test]: Cuda test fixes

Co-authored-by: njindal <njindal@adobe.com>

* Remove required from tracker_project_name (#3260)

Remove required from tracker_project_name.

As observed by https://github.com/off99555 in https://github.com/huggingface/diffusers/issues/2695#issuecomment-1470755050, it already has a default value.

* adding required parameters while calling the get_up_block and get_down_block  (#3210)

* removed unnecessary parameters from get_up_block and get_down_block functions

* adding resnet_skip_time_act, resnet_out_scale_factor and cross_attention_norm to get_up_block and get_down_block functions

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* [docs] Update interface in repaint.mdx (#3119)

Update repaint.mdx

accomodate to #1701

* Update IF name to XL (#3262)

Co-authored-by: multimodalart <joaopaulo.passos+multimodal@gmail.com>

* fix typo in score sde pipeline (#3132)

* Fix typo in textual inversion JAX training script (#3123)

The pipeline is built as `pipe` but then used as `pipeline`.

* AudioDiffusionPipeline - fix encode method after config changes (#3114)

* config fixes

* deprecate get_input_dims

* Revert "Revert "[Community Pipelines] Update lpw_stable_diffusion pipeline"" (#3265)

Revert "Revert "[Community Pipelines] Update lpw_stable_diffusion pipeline" (#3201)"

This reverts commit 91a2a80eb2f98a9f64b9e287715add244dc6f2f3.

* Fix community pipelines (#3266)

* update notebook (#3259)

Co-authored-by: yiyixuxu <yixu@yis-macbook-pro.lan>

* [docs] add notes for stateful model changes (#3252)

* [docs] add notes for stateful model changes

* Update docs/source/en/optimization/fp16.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* link to accelerate docs for discarding hooks

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* [LoRA] quality of life improvements in the loading semantics and docs (#3180)

* 👽 qol improvements for LoRA.

* better function name?

* fix: LoRA weight loading with the new format.

* address Patrick's comments.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* change wording around encouraging the use of load_lora_weights().

* fix: function name.

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [Community Pipelines] EDICT pipeline implementation (#3153)

* EDICT pipeline initial commit

- Starting point taking from https://github.com/Joqsan/edict-diffusion

* refactor __init__() method

* minor refactoring

* refactor scheduler code

- remove scheduler and move its methods to the EDICTPipeline class

* make CFG optional
- refactor encode_prompt().
- include optional generator for sampling with vae.
- minor variable renaming

* add EDICT pipeline description to README.md

* replace preprocess() with VaeImageProcessor

* run make style and make quality commands

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [Docs]zh translated docs update (#3245)

* zh translated docs update

* update _toctree

* Update logging.mdx (#2863)

Fix typos

* Add multiple conditions to StableDiffusionControlNetInpaintPipeline (#3125)

* try multi controlnet inpaint

* multi controlnet inpaint

* multi controlnet inpaint

* Let's make sure that dreambooth always uploads to the Hub (#3272)

* Update Dreambooth README

* Adapt all docs as well

* automatically write model card

* fix

* make style

* Diffedit Zero-Shot Inpainting Pipeline (#2837)

* Update Pix2PixZero Auto-correlation Loss

* Add Stable Diffusion DiffEdit pipeline

* Add draft documentation and import code

* Bugfixes and refactoring

* Add option to not decode latents in the inversion process

* Harmonize preprocessing

* Revert "Update Pix2PixZero Auto-correlation Loss"

This reverts commit b218062fed08d6cc164206d6cb852b2b7b00847a.

* Update annotations

* rename `compute_mask` to `generate_mask`

* Update documentation

* Update docs

* Update Docs

* Fix copy

* Change shape of output latents to batch first

* Update docs

* Add first draft for tests

* Bugfix and update tests

* Add `cross_attention_kwargs` support for all pipeline methods

* Fix Copies

* Add support for PIL image latents

Add support for mask broadcasting

Update docs and tests

Align `mask` argument to `mask_image`

Remove height and width arguments

* Enable MPS Tests

* Move example docstrings

* Fix test

* Fix test

* fix pipeline inheritance

* Harmonize `prepare_image_latents` with StableDiffusionPix2PixZeroPipeline

* Register modules set to `None` in config for `test_save_load_optional_components`

* Move fixed logic to specific test class

* Clean changes to other pipelines

* Update new tests to coordinate with #2953

* Update slow tests for better results

* Safety to avoid potential problems with torch.inference_mode

* Add reference in SD Pipeline Overview

* Fix tests again

* Enforce determinism in noise for generate_mask

* Fix copies

* Widen test tolerance for fp16 based on `test_stable_diffusion_upscale_pipeline_fp16`

* Add LoraLoaderMixin and update `prepare_image_latents`

* clean up repeat and reg

* bugfix

* Remove invalid args from docs

Suppress spurious warning by repeating image before latent to mask gen

* add constant learning rate with custom rule (#3133)

* add constant lr with rules

* add constant with rules in TYPE_TO_SCHEDULER_FUNCTION

* add constant lr rate with rule

* hotfix code quality

* fix doc style

* change name constant_with_rules to piecewise constant

* Allow disabling torch 2_0 attention (#3273)

* Allow disabling torch 2_0 attention

* make style

* Update src/diffusers/models/attention.py

* [doc] add link to training script (#3271)

add link to training script

Co-authored-by: yiyixuxu <yixu@yis-macbook-pro.lan>

* temp disable spectogram diffusion tests (#3278)

The note-seq package throws an error on import because the default installed version of Ipython
is not compatible with python 3.8 which we run in the CI.
https://github.com/huggingface/diffusers/actions/runs/4830121056/jobs/8605954838#step:7:9

* Changed sample[0] to images[0] (#3304)

A pipeline object stores the results in `images` not in `sample`.
Current code blocks don't work.

* Typo in tutorial (#3295)

* Torch compile graph fix (#3286)

* fix more

* Fix more

* fix more

* Apply suggestions from code review

* fix

* make style

* make fix-copies

* fix

* make sure torch compile

* Clean

* fix test

* Postprocessing refactor img2img (#3268)

* refactor img2img VaeImageProcessor.postprocess

* remove copy from for init, run_safety_checker, decode_latents

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: yiyixuxu <yixu@yis-macbook-pro.lan>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* [Torch 2.0 compile] Fix more torch compile breaks (#3313)

* Fix more torch compile breaks

* add tests

* Fix all

* fix controlnet

* fix more

* Add Horace He as co-author.
>
>
Co-authored-by: Horace He <horacehe2007@yahoo.com>

* Add Horace He as co-author.

Co-authored-by: Horace He <horacehe2007@yahoo.com>

---------

Co-authored-by: Horace He <horacehe2007@yahoo.com>

* fix: scale_lr and sync example readme and docs. (#3299)

* fix: scale_lr and sync example readme and docs.

* fix doc link.

* Update stable_diffusion.mdx (#3310)

fixed import statement

* Fix missing variable assign in DeepFloyd-IF-II (#3315)

Fix missing variable assign

lol

* Correct doc build for patch releases (#3316)

Update build_documentation.yml

* Add Stable Diffusion RePaint to community pipelines (#3320)

* Add Stable Diffsuion RePaint to community pipelines

- Adds Stable Diffsuion RePaint to community pipelines
- Add Readme enty for pipeline

* Fix: Remove wrong import

- Remove wrong import
- Minor change in comments

* Fix: Code formatting of stable_diffusion_repaint

* Fix: ruff errors in stable_diffusion_repaint

* Fix multistep dpmsolver for cosine schedule (suitable for deepfloyd-if) (#3314)

* fix multistep dpmsolver for cosine schedule (deepfloy-if)

* fix a typo

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update all dpmsolver (singlestep, multistep, dpm, dpm++) for cosine noise schedule

* add test, fix style

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [docs] Improve LoRA docs (#3311)

* update docs

* add to toctree

* apply feedback

* Added input pretubation (#3292)

* Added input pretubation

* Fixed spelling

* Update write_own_pipeline.mdx (#3323)

* update controlling generation doc with latest goodies. (#3321)

* [Quality] Make style (#3341)

* Fix config dpm (#3343)

* Add the SDE variant of DPM-Solver and DPM-Solver++ (#3344)

* add SDE variant of DPM-Solver and DPM-Solver++

* add test

* fix typo

* fix typo

* Add upsample_size to AttnUpBlock2D, AttnDownBlock2D (#3275)

The argument `upsample_size` needs to be added to these modules to allow compatibility with other blocks that require this argument.

* Add UniDiffuser classes to __init__ files, modify transformer block to support pre- and post-LN, add fast default tests, fix some bugs.

* Update fast tests to use test checkpoints stored on the hub and to better match the reference UniDiffuser implementation.

* Fix code with make style.

* Revert "Fix code style with make style."

This reverts commit 10a174a12c82e6abd3d5a57665719a03dbb85ca7.

* Add self.image_encoder, self.text_decoder to list of models to offload to CPU in the enable_sequential_cpu_offload(...)/enable_model_cpu_offload(...) methods to make test_cpu_offload_forward_pass pass.

* Fix code quality with make style.

* Support using a data type embedding for UniDiffuser-v1.

* Add fast test for checking UniDiffuser-v1 sampling.

* Make changes so that the repository consistency tests pass.

* Add UniDiffuser dummy objects via make fix-copies.

* Fix bugs and make improvements to the UniDiffuser pipeline:
	- Improve batch size inference and fix bugs when num_images_per_prompt or num_prompts_per_image > 1
	- Add tests for num_images_per_prompt, num_prompts_per_image > 1
	- Improve check_inputs, especially regarding checking supplied latents
	- Add reset_mode method so that mode inference can be re-enabled after mode is set manually
	- Fix some warnings related to accessing class members directly instead of through their config
	- Small amount of refactoring in pipeline_unidiffuser.py

* Fix code style with make style.

* Add/edit docstrings for added classes and public pipeline methods. Also do some light refactoring.

* Add documentation for UniDiffuser and fix some typos/formatting in docstrings.

* Fix code with make style.

* Refactor and improve the UniDiffuser convert_from_ckpt.py script.

* Move the UniDiffusers convert_from_ckpy.py script to diffusers/scripts/convert_unidiffuser_to_diffusers.py

* Fix code quality via make style.

* Improve UniDiffuser slow tests.

* make style

* Fix some typos in the UniDiffuser docs.

* Remove outdated logic based on transformers version in UniDiffuser pipeline __init__.py

* Remove dependency on einops by refactoring einops operations to pure torch operations.

* make style

* Add slow test on full checkpoint for joint mode and correct expected image slices/text prefixes.

* make style

* Fix mixed precision issue by wrapping the offending code with the torch.autocast context manager.

* Revert "Fix mixed precision issue by wrapping the offending code with the torch.autocast context manager."

This reverts commit 1a58958ab4f024dbc4c90a6404c2e66210db6d00.

* Add fast test for CUDA/fp16 model behavior (currently failing).

* Fix the mixed precision issue and add additional tests of the pipeline cuda/fp16 functionality.

* make style

* Use a CLIPVisionModelWithProjection instead of CLIPVisionModel for image_encoder to better match the original UniDiffuser implementation.

* Make style and remove some testing code.

* Fix shape errors for the 'joint' and 'img2text' modes.

* Fix tests and remove some testing code.

* Add option to use fixed latents for UniDiffuserPipelineSlowTests and fix issue in modeling_text_decoder.py.

* Improve UniDiffuser docs, particularly the usage examples, and improve slow tests with new expected outputs.

* make style

* Fix examples to load model in float16.

* In image-to-text mode, sample from the autoencoder moment distribution instead of always getting its mode.

* make style

* When encoding the image using the VAE, scale the image latents by the VAE's scaling factor.

* make style

* Clean up code and make slow tests pass.

* make fix-copies

* [docs] Fix docstring (#3334)

fix docstring

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* if dreambooth lora (#3360)

* update IF stage I pipelines

add fixed variance schedulers and lora loading

* added kv lora attn processor

* allow loading into alternative lora attn processor

* make vae optional

* throw away predicted variance

* allow loading into added kv lora layer

* allow load T5

* allow pre compute text embeddings

* set new variance type in schedulers

* fix copies

* refactor all prompt embedding code

class prompts are now included in pre-encoding code
max tokenizer length is now configurable
embedding attention mask is now configurable

* fix for when variance type is not defined on scheduler

* do not pre compute validation prompt if not present

* add example test for if lora dreambooth

* add check for train text encoder and pre compute text embeddings

* Postprocessing refactor all others (#3337)

* add text2img

* fix-copies

* add

* add all other pipelines

* add

* add

* add

* add

* add

* make style

* style + fix copies

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>

* [docs] Improve safetensors docstring (#3368)

* clarify safetensor docstring

* fix typo

* apply feedback

* add: a warning message when using xformers in a PT 2.0 env. (#3365)

* add: a warning message when using xformers in a PT 2.0 env.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* StableDiffusionInpaintingPipeline - resize image w.r.t height and width (#3322)

* StableDiffusionInpaintingPipeline now resizes input images and masks w.r.t to passed input height and width. Default is already set to 512. This addresses the common tensor mismatch error. Also moved type check into relevant funciton to keep main pipeline body tidy.

* Fixed StableDiffusionInpaintingPrepareMaskAndMaskedImageTests

Due to previous commit these tests were failing as height and width need to be passed into the prepare_mask_and_masked_image function, I have updated the code and added a height/width variable per unit test as it seemed more appropriate than the current hard coded solution

* Added a resolution test to StableDiffusionInpaintPipelineSlowTests

this unit test simply gets the input and resizes it into some that would fail (e.g. would throw a tensor mismatch error/not a mult of 8). Then passes it through the pipeline and verifies it produces output with correct dims w.r.t the passed height and width

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make style

* [docs] Adapt a model (#3326)

* first draft

* apply feedback

* conv_in.weight thrown away

* [docs] Load safetensors (#3333)

* safetensors

* apply feedback

* apply feedback

* Apply suggestions from code review

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make style

* [Docs] Fix stable_diffusion.mdx typo (#3398)

Fix typo in last code block. Correct "prommpts" to "prompt"

* Support ControlNet v1.1 shuffle properly (#3340)

* add inferring_controlnet_cond_batch

* Revert "add inferring_controlnet_cond_batch"

This reverts commit abe8d6311d4b7f5b9409ca709c7fabf80d06c1a9.

* set guess_mode to True
whenever global_pool_conditions is True

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* nit

* add integration test

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [Tests] better determinism (#3374)

* enable deterministic pytorch and cuda operations.

* disable manual seeding.

* make style && make quality for unet_2d tests.

* enable determinism for the unet2dconditional model.

* add CUBLAS_WORKSPACE_CONFIG for better reproducibility.

* relax tolerance (very weird issue, though).

* revert to torch manual_seed() where needed.

* relax more tolerance.

* better placement of the cuda variable and relax more tolerance.

* enable determinism for 3d condition model.

* relax tolerance.

* add: determinism to alt_diffusion.

* relax tolerance for alt diffusion.

* dance diffusion.

* dance diffusion is flaky.

* test_dict_tuple_outputs_equivalent edit.

* fix two more tests.

* fix more ddim tests.

* fix: argument.

* change to diff in place of difference.

* fix: test_save_load call.

* test_save_load_float16 call.

* fix: expected_max_diff

* fix: paint by example.

* relax tolerance.

* add determinism to 1d unet model.

* torch 2.0 regressions seem to be brutal

* determinism to vae.

* add reason to skipping.

* up tolerance.

* determinism to vq.

* determinism to cuda.

* determinism to the generic test pipeline file.

* refactor general pipelines testing a bit.

* determinism to alt diffusion i2i

* up tolerance for alt diff i2i and audio diff

* up tolerance.

* determinism to audioldm

* increase tolerance for audioldm lms.

* increase tolerance for paint by paint.

* increase tolerance for repaint.

* determinism to cycle diffusion and sd 1.

* relax tol for cycle diffusion 🚲

* relax tol for sd 1.0

* relax tol for controlnet.

* determinism to img var.

* relax tol for img variation.

* tolerance to i2i sd

* make style

* determinism to inpaint.

* relax tolerance for inpaiting.

* determinism for inpainting legacy

* relax tolerance.

* determinism to instruct pix2pix

* determinism to model editing.

* model editing tolerance.

* panorama determinism

* determinism to pix2pix zero.

* determinism to sag.

* sd 2. determinism

* sd. tolerance

* disallow tf32 matmul.

* relax tolerance is all you need.

* make style and determinism to sd 2 depth

* relax tolerance for depth.

* tolerance to diffedit.

* tolerance to sd 2 inpaint.

* up tolerance.

* determinism in upscaling.

* tolerance in upscaler.

* more tolerance relaxation.

* determinism to v pred.

* up tol for v_pred

* unclip determinism

* determinism to unclip img2img

* determinism to text to video.

* determinism to last set of tests

* up tol.

* vq cumsum doesn't have a deterministic kernel

* relax tol

* relax tol

* [docs] Add transformers to install (#3388)

add transformers to install

* [deepspeed] partial ZeRO-3 support (#3076)

* [deepspeed] partial ZeRO-3 support

* cleanup

* improve deepspeed fixes

* Improve

* make style

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add omegaconf for tests (#3400)

Add omegaconfg

* Fix various bugs with LoRA Dreambooth and Dreambooth script (#3353)

* Improve checkpointing lora

* fix more

* Improve doc string

* Update src/diffusers/loaders.py

* make stytle

* Apply suggestions from code review

* Update src/diffusers/loaders.py

* Apply suggestions from code review

* Apply suggestions from code review

* better

* Fix all

* Fix multi-GPU dreambooth

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix all

* make style

* make style

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix docker file (#3402)

* up

* up

* fix: deepseepd_plugin retrieval from accelerate state (#3410)

* [Docs] Add `sigmoid` beta_scheduler to docstrings of relevant Schedulers (#3399)

* Add `sigmoid` beta scheduler to `DDPMScheduler` docstring

* Add `sigmoid` beta scheduler to `RePaintScheduler` docstring

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Don't install accelerate and transformers from source (#3415)

* Don't install transformers and accelerate from source (#3414)

* Improve fast tests (#3416)

Update pr_tests.yml

* attention refactor: the trilogy  (#3387)

* Replace `AttentionBlock` with `Attention`

* use _from_deprecated_attn_block check re: @patrickvonplaten

* [Docs] update the PT 2.0 optimization doc with latest findings (#3370)

* add: benchmarking stats for A100 and V100.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* address patrick's comments.

* add: rtx 4090 stats

* ⚔ benchmark reports done

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* 3313 pr link.

* add: plots.

Co-authored-by: Pedro <pedro@huggingface.co>

* fix formattimg

* update number percent.

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix style rendering (#3433)

* Fix style rendering.

* Fix typo

* unCLIP scheduler do not use note (#3417)

* Replace deprecated command with environment file (#3409)

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix warning message pipeline loading (#3446)

* add stable diffusion tensorrt img2img pipeline (#3419)

* add stable diffusion tensorrt img2img pipeline

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* update docstrings

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

---------

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>

* Refactor controlnet and add img2img and inpaint (#3386)

* refactor controlnet and add img2img and inpaint

* First draft to get pipelines to work

* make style

* Fix more

* Fix more

* More tests

* Fix more

* Make inpainting work

* make style and more tests

* Apply suggestions from code review

* up

* make style

* Fix imports

* Fix more

* Fix more

* Improve examples

* add test

* Make sure import is correctly deprecated

* Make sure everything works in compile mode

* make sure authorship is correctly attributed

* [Scheduler] DPM-Solver (++) Inverse Scheduler (#3335)

* Add DPM-Solver Multistep Inverse Scheduler

* Add draft tests for DiffEdit

* Add inverse sde-dpmsolver steps to tune image diversity from inverted latents

* Fix tests

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [Docs] Fix incomplete docstring for resnet.py (#3438)

Fix incomplete docstrings for resnet.py

* fix tiled vae blend extent range (#3384)

fix tiled vae bleand extent range

* Small update to "Next steps" section (#3443)

Small update to "Next steps" section:

- PyTorch 2 is recommended.
- Updated improvement figures.

* Allow arbitrary aspect ratio in IFSuperResolutionPipeline (#3298)

* Update pipeline_if_superresolution.py

Allow arbitrary aspect ratio in IFSuperResolutionPipeline by using the input image shape

* IFSuperResolutionPipeline: allow the user to override the height and width through the arguments

* update IFSuperResolutionPipeline width/height doc string to match StableDiffusionInpaintPipeline conventions

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Adding 'strength' parameter to StableDiffusionInpaintingPipeline  (#3424)

* Added explanation of 'strength' parameter

* Added get_timesteps function which relies on new strength parameter

* Added `strength` parameter which defaults to 1.

* Swapped ordering so `noise_timestep` can be calculated before masking the image

this is required when you aren't applying 100% noise to the masked region, e.g. strength < 1.

* Added strength to check_inputs, throws error if out of range

* Changed `prepare_latents` to initialise latents w.r.t strength

inspired from the stable diffusion img2img pipeline, init latents are initialised by converting the init image into a VAE latent and adding noise (based upon the strength parameter passed in), e.g. random when strength = 1, or the init image at strength = 0.

* WIP: Added a unit test for the new strength parameter in the StableDiffusionInpaintingPipeline

still need to add correct regression values

* Created a is_strength_max to initialise from pure random noise

* Updated unit tests w.r.t new strength parameter + fixed new strength unit test

* renamed parameter to avoid confusion with variable of same name

* Updated regression values for new strength test - now passes

* removed 'copied from' comment as this method is now different and divergent from the cpy

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Ensure backwards compatibility for prepare_mask_and_masked_image

created a return_image boolean and initialised to false

* Ensure backwards compatibility for prepare_latents

* Fixed copy check typo

* Fixes w.r.t backward compibility changes

* make style

* keep function argument ordering same for backwards compatibility in callees with copied from statements

* make fix-copies

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: William Berman <WLBberman@gmail.com>

* [WIP] Bugfix - Pipeline.from_pretrained is broken when the pipeline is partially downloaded (#3448)

Added bugfix using f strings.

* Fix gradient checkpointing bugs in freezing part of models (requires_grad=False) (#3404)

* gradient checkpointing bug fix

* bug fix; changes for reviews

* reformat

* reformat

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Make dreambooth lora more robust to orig unet (#3462)

* Make dreambooth lora more robust to orig unet

* up

* Reduce peak VRAM by releasing large attention tensors (as soon as they're unnecessary) (#3463)

Release large tensors in attention (as soon as they're no longer required). Reduces peak VRAM by nearly 2 GB for 1024x1024 (even after slicing), and the savings scale up with image size.

* Add min snr to text2img lora training script (#3459)

add min snr to text2img lora training script

* Add inpaint lora scale support (#3460)

* add inpaint lora scale support

* add inpaint lora scale test

---------

Co-authored-by: yueyang.hyy <yueyang.hyy@alibaba-inc.com>

* [From ckpt] Fix from_ckpt (#3466)

* Correct from_ckpt

* make style

* Update full dreambooth script to work with IF (#3425)

* Add IF dreambooth docs (#3470)

* parameterize pass single args through tuple (#3477)

* attend and excite tests disable determinism on the class level (#3478)

* dreambooth docs torch.compile note (#3471)

* dreambooth docs torch.compile note

* Update examples/dreambooth/README.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update examples/dreambooth/README.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* add: if entry in the dreambooth training docs. (#3472)

* [docs] Textual inversion inference (#3473)

* add textual inversion inference to docs

* add to toctree

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* [docs] Distributed inference (#3376)

* distributed inference

* move to inference section

* apply feedback

* update with split_between_processes

* apply feedback

* [{Up,Down}sample1d] explicit view kernel size as number elements in flattened indices (#3479)

explicit view kernel size as number elements in flattened indices

* mps & onnx tests rework (#3449)

* Remove ONNX tests from PR.

They are already a part of push_tests.yml.

* Remove mps tests from PRs.

They are already performed on push.

* Fix workflow name for fast push tests.

* Extract mps tests to a workflow.

For better control/filtering.

* Remove --extra-index-url from mps tests

* Increase tolerance of mps test

This test passes in my Mac (Ventura 13.3) but fails in the CI hardware
(Ventura 13.2). I ran the local tests following the same steps that
exist in the CI workflow.

* Temporarily run mps tests on pr

So we can test.

* Revert "Temporarily run mps tests on pr"

Tests passed, go back to running on push.

* [Attention processor] Better warning message when shifting to `AttnProcessor2_0` (#3457)

* add: debugging to enabling memory efficient processing

* add: better warning message.

* [Docs] add note on local directory path. (#3397)

add note on local directory path.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Refactor full determinism (#3485)

* up

* fix more

* Apply suggestions from code review

* fix more

* fix more

* Check it

* Remove 16:8

* fix more

* fix more

* fix more

* up

* up

* Test only stable diffusion

* Test only two files

* up

* Try out spinning up processes that can be killed

* up

* Apply suggestions from code review

* up

* up

* Fix DPM single (#3413)

* Fix DPM single

* add test

* fix one more bug

* Apply suggestions from code review

Co-authored-by: StAlKeR7779 <stalkek7779@yandex.ru>

---------

Co-authored-by: StAlKeR7779 <stalkek7779@yandex.ru>

* Add `use_Karras_sigmas` to DPMSolverSinglestepScheduler (#3476)

* add use_karras_sigmas

* add karras test

* add doc

* Adds local_files_only bool to prevent forced online connection (#3486)

* make style

* [Docs] Korean translation (optimization, training) (#3488)

* feat) optimization kr translation

* fix) typo, italic setting

* feat) dreambooth, text2image kr

* feat) lora kr

* fix) LoRA

* fix) fp16 fix

* fix) doc-builder style

* fix) fp16 일부 단어 수정

* fix) fp16 style fix

* fix) opt, training docs update

* feat) toctree update

* feat) toctree update

---------

Co-authored-by: Chanran Kim <seriousran@gmail.com>

* DataLoader respecting EXIF data in Training Images (#3465)

* DataLoader will now bake in any transforms or image manipulations contained in the EXIF

Images may have rotations stored in EXIF. Training using such images will cause those transforms to be ignored while training and thus produce unexpected results

* Fixed the Dataloading EXIF issue in main DreamBooth training as well

* Run make style (black & isort)

* make style

* feat: allow disk offload for diffuser models (#3285)

* allow disk offload for diffuser models

* sort import

* add max_memory argument

* Changed sample[0] to images[0] (#3304)

A pipeline object stores the results in `images` not in `sample`.
Current code blocks don't work.

* Typo in tutorial (#3295)

* Torch compile graph fix (#3286)

* fix more

* Fix more

* fix more

* Apply suggestions from code review

* fix

* make style

* make fix-copies

* fix

* make sure torch compile

* Clean

* fix test

* Postprocessing refactor img2img (#3268)

* refactor img2img VaeImageProcessor.postprocess

* remove copy from for init, run_safety_checker, decode_latents

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: yiyixuxu <yixu@yis-macbook-pro.lan>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* [Torch 2.0 compile] Fix more torch compile breaks (#3313)

* Fix more torch compile breaks

* add tests

* Fix all

* fix controlnet

* fix more

* Add Horace He as co-author.
>
>
Co-authored-by: Horace He <horacehe2007@yahoo.com>

* Add Horace He as co-author.

Co-authored-by: Horace He <horacehe2007@yahoo.com>

---------

Co-authored-by: Horace He <horacehe2007@yahoo.com>

* fix: scale_lr and sync example readme and docs. (#3299)

* fix: scale_lr and sync example readme and docs.

* fix doc link.

* Update stable_diffusion.mdx (#3310)

fixed import statement

* Fix missing variable assign in DeepFloyd-IF-II (#3315)

Fix missing variable assign

lol

* Correct doc build for patch releases (#3316)

Update build_documentation.yml

* Add Stable Diffusion RePaint to community pipelines (#3320)

* Add Stable Diffsuion RePaint to community pipelines

- Adds Stable Diffsuion RePaint to community pipelines
- Add Readme enty for pipeline

* Fix: Remove wrong import

- Remove wrong import
- Minor change in comments

* Fix: Code formatting of stable_diffusion_repaint

* Fix: ruff errors in stable_diffusion_repaint

* Fix multistep dpmsolver for cosine schedule (suitable for deepfloyd-if) (#3314)

* fix multistep dpmsolver for cosine schedule (deepfloy-if)

* fix a typo

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update all dpmsolver (singlestep, multistep, dpm, dpm++) for cosine noise schedule

* add test, fix style

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [docs] Improve LoRA docs (#3311)

* update docs

* add to toctree

* apply feedback

* Added input pretubation (#3292)

* Added input pretubation

* Fixed spelling

* Update write_own_pipeline.mdx (#3323)

* update controlling generation doc with latest goodies. (#3321)

* [Quality] Make style (#3341)

* Fix config dpm (#3343)

* Add the SDE variant of DPM-Solver and DPM-Solver++ (#3344)

* add SDE variant of DPM-Solver and DPM-Solver++

* add test

* fix typo

* fix typo

* Add upsample_size to AttnUpBlock2D, AttnDownBlock2D (#3275)

The argument `upsample_size` needs to be added to these modules to allow compatibility with other blocks that require this argument.

* Rename --only_save_embeds to --save_as_full_pipeline (#3206)

* Set --only_save_embeds to False by default

Due to how the option is named, it makes more sense to behave like this.

* Refactor only_save_embeds to save_as_full_pipeline

* [AudioLDM] Generalise conversion script (#3328)

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Fix TypeError when using prompt_embeds and negative_prompt (#2982)

* test: Added test case

* fix: fixed type checking issue on _encode_prompt

* fix: fixed copies consistency

* fix: one copy was not sufficient

* Fix pipeline class on README (#3345)

Update README.md

* Inpainting: typo in docs (#3331)

Typo in docs

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add `use_Karras_sigmas` to LMSDiscreteScheduler (#3351)

* add karras sigma to lms discrete scheduler

* add test for lms_scheduler karras

* reformat test lms

* Batched load of textual inversions (#3277)

* Batched load of textual inversions

- Only call resize_token_embeddings once per batch as it is the most expensive operation
- Allow pretrained_model_name_or_path and token to be an optional list
- Remove Dict from type annotation pretrained_model_name_or_path as it was not supported in this function
- Add comment that single files (e.g. .pt/.safetensors) are supported
- Add comment for token parameter
- Convert token override log message from warning to info

* Update src/diffusers/loaders.py

Check for duplicate tokens

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update condition for None tokens

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make fix-copies

* [docs] Fix docstring (#3334)

fix docstring

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* if dreambooth lora (#3360)

* update IF stage I pipelines

add fixed variance schedulers and lora loading

* added kv lora attn processor

* allow loading into alternative lora attn processor

* make vae optional

* throw away predicted variance

* allow loading into added kv lora layer

* allow load T5

* allow pre compute text embeddings

* set new variance type in schedulers

* fix copies

* refactor all prompt embedding code

class prompts are now included in pre-encoding code
max tokenizer length is now configurable
embedding attention mask is now configurable

* fix for when variance type is not defined on scheduler

* do not pre compute validation prompt if not present

* add example test for if lora dreambooth

* add check for train text encoder and pre compute text embeddings

* Postprocessing refactor all others (#3337)

* add text2img

* fix-copies

* add

* add all other pipelines

* add

* add

* add

* add

* add

* make style

* style + fix copies

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>

* [docs] Improve safetensors docstring (#3368)

* clarify safetensor docstring

* fix typo

* apply feedback

* add: a warning message when using xformers in a PT 2.0 env. (#3365)

* add: a warning message when using xformers in a PT 2.0 env.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* StableDiffusionInpaintingPipeline - resize image w.r.t height and width (#3322)

* StableDiffusionInpaintingPipeline now resizes input images and masks w.r.t to passed input height and width. Default is already set to 512. This addresses the common tensor mismatch error. Also moved type check into relevant funciton to keep main pipeline body tidy.

* Fixed StableDiffusionInpaintingPrepareMaskAndMaskedImageTests

Due to previous commit these tests were failing as height and width need to be passed into the prepare_mask_and_masked_image function, I have updated the code and added a height/width variable per unit test as it seemed more appropriate than the current hard coded solution

* Added a resolution test to StableDiffusionInpaintPipelineSlowTests

this unit test simply gets the input and resizes it into some that would fail (e.g. would throw a tensor mismatch error/not a mult of 8). Then passes it through the pipeline and verifies it produces output with correct dims w.r.t the passed height and width

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make style

* [docs] Adapt a model (#3326)

* first draft

* apply feedback

* conv_in.weight thrown away

* [docs] Load safetensors (#3333)

* safetensors

* apply feedback

* apply feedback

* Apply suggestions from code review

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make style

* [Docs] Fix stable_diffusion.mdx typo (#3398)

Fix typo in last code block. Correct "prommpts" to "prompt"

* Support ControlNet v1.1 shuffle properly (#3340)

* add inferring_controlnet_cond_batch

* Revert "add inferring_controlnet_cond_batch"

This reverts commit abe8d6311d4b7f5b9409ca709c7fabf80d06c1a9.

* set guess_mode to True
whenever global_pool_conditions is True

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* nit

* add integration test

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [Tests] better determinism (#3374)

* enable deterministic pytorch and cuda operations.

* disable manual seeding.

* make style && make quality for unet_2d tests.

* enable determinism for the unet2dconditional model.

* add CUBLAS_WORKSPACE_CONFIG for better reproducibility.

* relax tolerance (very weird issue, though).

* revert to torch manual_seed() where needed.

* relax more tolerance.

* better placement of the cuda variable and relax more tolerance.

* enable determinism for 3d condition model.

* relax tolerance.

* add: determinism to alt_diffusion.

* relax tolerance for alt diffusion.

* dance diffusion.

* dance diffusion is flaky.

* test_dict_tuple_outputs_equivalent edit.

* fix two more tests.

* fix more ddim tests.

* fix: argument.

* change to diff in place of difference.

* fix: test_save_load call.

* test_save_load_float16 call.

* fix: expected_max_diff

* fix: paint by example.

* relax tolerance.

* add determinism to 1d unet model.

* torch 2.0 regressions seem to be brutal

* determinism to vae.

* add reason to skipping.

* up tolerance.

* determinism to vq.

* determinism to cuda.

* determinism to the generic test pipeline file.

* refactor general pipelines testing a bit.

* determinism to alt diffusion i2i

* up tolerance for alt diff i2i and audio diff

* up tolerance.

* determinism to audioldm

* increase tolerance for audioldm lms.

* increase tolerance for paint by paint.

* increase tolerance for repaint.

* determinism to cycle diffusion and sd 1.

* relax tol for cycle diffusion 🚲

* relax tol for sd 1.0

* relax tol for controlnet.

* determinism to img var.

* relax tol for img variation.

* tolerance to i2i sd

* make style

* determinism to inpaint.

* relax tolerance for inpaiting.

* determinism for inpainting legacy

* relax tolerance.

* determinism to instruct pix2pix

* determinism to model editing.

* model editing tolerance.

* panorama determinism

* determinism to pix2pix zero.

* determinism to sag.

* sd 2. determinism

* sd. tolerance

* disallow tf32 matmul.

* relax tolerance is all you need.

* make style and determinism to sd 2 depth

* relax tolerance for depth.

* tolerance to diffedit.

* tolerance to sd 2 inpaint.

* up tolerance.

* determinism in upscaling.

* tolerance in upscaler.

* more tolerance relaxation.

* determinism to v pred.

* up tol for v_pred

* unclip determinism

* determinism to unclip img2img

* determinism to text to video.

* determinism to last set of tests

* up tol.

* vq cumsum doesn't have a deterministic kernel

* relax tol

* relax tol

* [docs] Add transformers to install (#3388)

add transformers to install

* [deepspeed] partial ZeRO-3 support (#3076)

* [deepspeed] partial ZeRO-3 support

* cleanup

* improve deepspeed fixes

* Improve

* make style

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add omegaconf for tests (#3400)

Add omegaconfg

* Fix various bugs with LoRA Dreambooth and Dreambooth script (#3353)

* Improve checkpointing lora

* fix more

* Improve doc string

* Update src/diffusers/loaders.py

* make stytle

* Apply suggestions from code review

* Update src/diffusers/loaders.py

* Apply suggestions from code review

* Apply suggestions from code review

* better

* Fix all

* Fix multi-GPU dreambooth

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix all

* make style

* make style

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix docker file (#3402)

* up

* up

* fix: deepseepd_plugin retrieval from accelerate state (#3410)

* [Docs] Add `sigmoid` beta_scheduler to docstrings of relevant Schedulers (#3399)

* Add `sigmoid` beta scheduler to `DDPMScheduler` docstring

* Add `sigmoid` beta scheduler to `RePaintScheduler` docstring

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Don't install accelerate and transformers from source (#3415)

* Don't install transformers and accelerate from source (#3414)

* Improve fast tests (#3416)

Update pr_tests.yml

* attention refactor: the trilogy  (#3387)

* Replace `AttentionBlock` with `Attention`

* use _from_deprecated_attn_block check re: @patrickvonplaten

* [Docs] update the PT 2.0 optimization doc with latest findings (#3370)

* add: benchmarking stats for A100 and V100.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* address patrick's comments.

* add: rtx 4090 stats

* ⚔ benchmark reports done

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* 3313 pr link.

* add: plots.

Co-authored-by: Pedro <pedro@huggingface.co>

* fix formattimg

* update number percent.

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix style rendering (#3433)

* Fix style rendering.

* Fix typo

* unCLIP scheduler do not use note (#3417)

* Replace deprecated command with environment file (#3409)

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* …
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet