diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml
index 483bbd1650da..e9e672e72dd5 100644
--- a/.github/ISSUE_TEMPLATE/bug-report.yml
+++ b/.github/ISSUE_TEMPLATE/bug-report.yml
@@ -13,7 +13,7 @@ body:
*Give your issue a fitting title. Assume that someone which very limited knowledge of diffusers can understand your issue. Add links to the source code, documentation other issues, pull requests etc...*
- 2. If your issue is about something not working, **always** provide a reproducible code snippet. The reader should be able to reproduce your issue by **only copy-pasting your code snippet into a Python shell**.
*The community cannot solve your issue if it cannot reproduce it. If your bug is related to training, add your training script and make everything needed to train public. Otherwise, just add a simple Python code snippet.*
- - 3. Add the **minimum amount of code / context that is needed to understand, reproduce your issue**.
+ - 3. Add the **minimum** amount of code / context that is needed to understand, reproduce your issue.
*Make the life of maintainers easy. `diffusers` is getting many issues every day. Make sure your issue is about one bug and one bug only. Make sure you add only the context, code needed to understand your issues - nothing more. Generally, every issue is a way of documenting this library, try to make it a good documentation entry.*
- 4. For issues related to community pipelines (i.e., the pipelines located in the `examples/community` folder), please tag the author of the pipeline in your issue thread as those pipelines are not maintained.
- type: markdown
@@ -61,21 +61,46 @@ body:
All issues are read by one of the core maintainers, so if you don't know who to tag, just leave this blank and
a core maintainer will ping the right person.
- Please tag fewer than 3 people.
-
- General library related questions: @patrickvonplaten and @sayakpaul
+ Please tag a maximum of 2 people.
+
+ Questions on DiffusionPipeline (Saving, Loading, From pretrained, ...):
+
+ Questions on pipelines:
+ - Stable Diffusion @yiyixuxu @DN6 @patrickvonplaten @sayakpaul @patrickvonplaten
+ - Stable Diffusion XL @yiyixuxu @sayakpaul @DN6 @patrickvonplaten
+ - Kandinsky @yiyixuxu @patrickvonplaten
+ - ControlNet @sayakpaul @yiyixuxu @DN6 @patrickvonplaten
+ - T2I Adapter @sayakpaul @yiyixuxu @DN6 @patrickvonplaten
+ - IF @DN6 @patrickvonplaten
+ - Text-to-Video / Video-to-Video @DN6 @sayakpaul @patrickvonplaten
+ - Wuerstchen @DN6 @patrickvonplaten
+ - Other: @yiyixuxu @DN6
+
+ Questions on models:
+ - UNet @DN6 @yiyixuxu @sayakpaul @patrickvonplaten
+ - VAE @sayakpaul @DN6 @yiyixuxu @patrickvonplaten
+ - Transformers/Attention @DN6 @yiyixuxu @sayakpaul @DN6 @patrickvonplaten
- Questions on the training examples: @williamberman, @sayakpaul, @yiyixuxu
+ Questions on Schedulers: @yiyixuxu @patrickvonplaten
- Questions on memory optimizations, LoRA, float16, etc.: @williamberman, @patrickvonplaten, and @sayakpaul
+ Questions on LoRA: @sayakpaul @patrickvonplaten
- Questions on schedulers: @patrickvonplaten and @williamberman
+ Questions on Textual Inversion: @sayakpaul @patrickvonplaten
- Questions on models and pipelines: @patrickvonplaten, @sayakpaul, and @williamberman (for community pipelines, please tag the original author of the pipeline)
+ Questions on Training:
+ - DreamBooth @sayakpaul @patrickvonplaten
+ - Text-to-Image Fine-tuning @sayakpaul @patrickvonplaten
+ - Textual Inversion @sayakpaul @patrickvonplaten
+ - ControlNet @sayakpaul @patrickvonplaten
+
+ Questions on Tests: @DN6 @sayakpaul @yiyixuxu
+
+ Questions on Documentation: @stevhliu
Questions on JAX- and MPS-related things: @pcuenca
- Questions on audio pipelines: @patrickvonplaten, @kashif, and @sanchit-gandhi
+ Questions on audio pipelines: @DN6 @patrickvonplaten
+
+
- Documentation: @stevhliu and @yiyixuxu
placeholder: "@Username ..."
diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml
index c31e179c7628..a15a5412c4e4 100644
--- a/.github/workflows/push_tests.yml
+++ b/.github/workflows/push_tests.yml
@@ -1,10 +1,11 @@
-name: Slow tests on main
+name: Slow Tests on main
on:
push:
branches:
- main
+
env:
DIFFUSERS_IS_CI: yes
HF_HOME: /mnt/cache
@@ -12,53 +13,115 @@ env:
MKL_NUM_THREADS: 8
PYTEST_TIMEOUT: 600
RUN_SLOW: yes
+ PIPELINE_USAGE_CUTOFF: 50000
jobs:
- run_slow_tests:
+ setup_torch_cuda_pipeline_matrix:
+ name: Setup Torch Pipelines CUDA Slow Tests Matrix
+ runs-on: docker-gpu
+ container:
+ image: diffusers/diffusers-pytorch-cpu # this is a CPU image, but we need it to fetch the matrix
+ options: --shm-size "16gb" --ipc host
+ outputs:
+ pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+ - name: Install dependencies
+ run: |
+ apt-get update && apt-get install libsndfile1-dev libgl1 -y
+ python -m pip install -e .[quality,test]
+ python -m pip install git+https://github.com/huggingface/accelerate.git
+
+ - name: Environment
+ run: |
+ python utils/print_env.py
+
+ - name: Fetch Pipeline Matrix
+ id: fetch_pipeline_matrix
+ run: |
+ matrix=$(python utils/fetch_torch_cuda_pipeline_test_matrix.py)
+ echo $matrix
+ echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
+
+ - name: Pipeline Tests Artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v2
+ with:
+ name: test-pipelines.json
+ path: reports
+
+ torch_pipelines_cuda_tests:
+ name: Torch Pipelines CUDA Slow Tests
+ needs: setup_torch_cuda_pipeline_matrix
strategy:
fail-fast: false
max-parallel: 1
matrix:
- config:
- - name: Slow PyTorch CUDA tests on Ubuntu
- framework: pytorch
- runner: docker-gpu
- image: diffusers/diffusers-pytorch-cuda
- report: torch_cuda
- - name: Slow Flax TPU tests on Ubuntu
- framework: flax
- runner: docker-tpu
- image: diffusers/diffusers-flax-tpu
- report: flax_tpu
- - name: Slow ONNXRuntime CUDA tests on Ubuntu
- framework: onnxruntime
- runner: docker-gpu
- image: diffusers/diffusers-onnxruntime-cuda
- report: onnx_cuda
-
- name: ${{ matrix.config.name }}
-
- runs-on: ${{ matrix.config.runner }}
-
+ module: ${{ fromJson(needs.setup_torch_cuda_pipeline_matrix.outputs.pipeline_test_matrix) }}
+ runs-on: docker-gpu
container:
- image: ${{ matrix.config.image }}
- options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ ${{ matrix.config.runner == 'docker-tpu' && '--privileged' || '--gpus 0'}}
-
+ image: diffusers/diffusers-pytorch-cuda
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+ - name: NVIDIA-SMI
+ run: |
+ nvidia-smi
+ - name: Install dependencies
+ run: |
+ apt-get update && apt-get install libsndfile1-dev libgl1 -y
+ python -m pip install -e .[quality,test]
+ python -m pip install git+https://github.com/huggingface/accelerate.git
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ - name: Slow PyTorch CUDA checkpoint tests on Ubuntu
+ env:
+ HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+ # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
+ CUBLAS_WORKSPACE_CONFIG: :16:8
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "not Flax and not Onnx" \
+ --make-reports=tests_pipeline_${{ matrix.module }}_cuda \
+ tests/pipelines/${{ matrix.module }}
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_pipeline_${{ matrix.module }}_cuda_stats.txt
+ cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v2
+ with:
+ name: pipeline_${{ matrix.module }}_test_reports
+ path: reports
+
+ torch_cuda_tests:
+ name: Torch CUDA Tests
+ runs-on: docker-gpu
+ container:
+ image: diffusers/diffusers-pytorch-cuda
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
defaults:
run:
shell: bash
-
+ strategy:
+ matrix:
+ module: [models, schedulers, lora, others]
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- - name: NVIDIA-SMI
- if : ${{ matrix.config.runner == 'docker-gpu' }}
- run: |
- nvidia-smi
-
- name: Install dependencies
run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y
@@ -70,47 +133,121 @@ jobs:
python utils/print_env.py
- name: Run slow PyTorch CUDA tests
- if: ${{ matrix.config.framework == 'pytorch' }}
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
-
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx and not compile" \
- --make-reports=tests_${{ matrix.config.report }} \
- tests/
+ -s -v -k "not Flax and not Onnx" \
+ --make-reports=tests_torch_cuda \
+ tests/${{ matrix.module }}
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_torch_cuda_stats.txt
+ cat reports/tests_torch_cuda_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v2
+ with:
+ name: torch_cuda_test_reports
+ path: reports
+
+ flax_tpu_tests:
+ name: Flax TPU Tests
+ runs-on: docker-tpu
+ container:
+ image: diffusers/diffusers-flax-tpu
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --privileged
+ defaults:
+ run:
+ shell: bash
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ apt-get update && apt-get install libsndfile1-dev libgl1 -y
+ python -m pip install -e .[quality,test]
+ python -m pip install git+https://github.com/huggingface/accelerate.git
+
+ - name: Environment
+ run: |
+ python utils/print_env.py
- name: Run slow Flax TPU tests
- if: ${{ matrix.config.framework == 'flax' }}
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: |
python -m pytest -n 0 \
-s -v -k "Flax" \
- --make-reports=tests_${{ matrix.config.report }} \
+ --make-reports=tests_flax_tpu \
tests/
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_flax_tpu_stats.txt
+ cat reports/tests_flax_tpu_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v2
+ with:
+ name: flax_tpu_test_reports
+ path: reports
+
+ onnx_cuda_tests:
+ name: ONNX CUDA Tests
+ runs-on: docker-gpu
+ container:
+ image: diffusers/diffusers-onnxruntime-cuda
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
+ defaults:
+ run:
+ shell: bash
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ apt-get update && apt-get install libsndfile1-dev libgl1 -y
+ python -m pip install -e .[quality,test]
+ python -m pip install git+https://github.com/huggingface/accelerate.git
+
+ - name: Environment
+ run: |
+ python utils/print_env.py
+
- name: Run slow ONNXRuntime CUDA tests
- if: ${{ matrix.config.framework == 'onnxruntime' }}
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
- --make-reports=tests_${{ matrix.config.report }} \
+ --make-reports=tests_onnx_cuda \
tests/
- name: Failure short reports
if: ${{ failure() }}
- run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
+ run: |
+ cat reports/tests_onnx_cuda_stats.txt
+ cat reports/tests_onnx_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
- name: ${{ matrix.config.report }}_test_reports
+ name: onnx_cuda_test_reports
path: reports
run_torch_compile_tests:
@@ -131,21 +268,17 @@ jobs:
- name: NVIDIA-SMI
run: |
nvidia-smi
-
- name: Install dependencies
run: |
python -m pip install -e .[quality,test,training]
-
- name: Environment
run: |
python utils/print_env.py
-
- name: Run example tests on GPU
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
-
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -192,11 +325,13 @@ jobs:
- name: Failure short reports
if: ${{ failure() }}
- run: cat reports/examples_torch_cuda_failures_short.txt
+ run: |
+ cat reports/examples_torch_cuda_stats.txt
+ cat reports/examples_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: examples_test_reports
- path: reports
+ path: reports
\ No newline at end of file
diff --git a/docker/diffusers-pytorch-compile-cuda/Dockerfile b/docker/diffusers-pytorch-compile-cuda/Dockerfile
index b0646084964e..a41be50f9d58 100644
--- a/docker/diffusers-pytorch-compile-cuda/Dockerfile
+++ b/docker/diffusers-pytorch-compile-cuda/Dockerfile
@@ -14,22 +14,23 @@ RUN apt update && \
libsndfile1-dev \
libgl1 \
python3.9 \
+ python3.9-dev \
python3-pip \
python3.9-venv && \
rm -rf /var/lib/apt/lists
# make sure to use venv
-RUN python3 -m venv /opt/venv
+RUN python3.9 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
-RUN python3 -m pip install --no-cache-dir --upgrade pip && \
- python3 -m pip install --no-cache-dir \
+RUN python3.9 -m pip install --no-cache-dir --upgrade pip && \
+ python3.9 -m pip install --no-cache-dir \
torch \
torchvision \
torchaudio \
invisible_watermark && \
- python3 -m pip install --no-cache-dir \
+ python3.9 -m pip install --no-cache-dir \
accelerate \
datasets \
hf-doc-builder \
diff --git a/docker/diffusers-pytorch-cuda/Dockerfile b/docker/diffusers-pytorch-cuda/Dockerfile
index fab3b7082765..4c447749da7b 100644
--- a/docker/diffusers-pytorch-cuda/Dockerfile
+++ b/docker/diffusers-pytorch-cuda/Dockerfile
@@ -25,8 +25,8 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
python3 -m pip install --no-cache-dir \
- torch \
- torchvision \
+ torch==2.0.1 \
+ torchvision==0.15.2 \
torchaudio \
invisible_watermark && \
python3 -m pip install --no-cache-dir \
diff --git a/docs/README.md b/docs/README.md
index e6408dc976fd..fd0a3a58b0aa 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -128,7 +128,7 @@ When adding a new pipeline:
- Possible an end-to-end example of how to use it
- Add all the pipeline classes that should be linked in the diffusion model. These classes should be added using our Markdown syntax. By default as follows:
-```
+```py
## XXXPipeline
[[autodoc]] XXXPipeline
@@ -138,7 +138,7 @@ When adding a new pipeline:
This will include every public method of the pipeline that is documented, as well as the `__call__` method that is not documented by default. If you just want to add additional methods that are not documented, you can put the list of all methods to add in a list that contains `all`.
-```
+```py
[[autodoc]] XXXPipeline
- all
- __call__
@@ -172,7 +172,7 @@ Arguments should be defined with the `Args:` (or `Arguments:` or `Parameters:`)
an indentation. The argument should be followed by its type, with its shape if it is a tensor, a colon, and its
description:
-```
+```py
Args:
n_layers (`int`): The number of layers of the model.
```
@@ -182,7 +182,7 @@ after the argument.
Here's an example showcasing everything so far:
-```
+```py
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
@@ -196,13 +196,13 @@ Here's an example showcasing everything so far:
For optional arguments or arguments with defaults we follow the following syntax: imagine we have a function with the
following signature:
-```
+```py
def my_function(x: str = None, a: float = 1):
```
then its documentation should look like this:
-```
+```py
Args:
x (`str`, *optional*):
This argument controls ...
@@ -235,14 +235,14 @@ building the return.
Here's an example of a single value return:
-```
+```py
Returns:
`List[int]`: A list of integers in the range [0, 1] --- 1 for a special token, 0 for a sequence token.
```
Here's an example of a tuple return, comprising several objects:
-```
+```py
Returns:
`tuple(torch.FloatTensor)` comprising various elements depending on the configuration ([`BertConfig`]) and inputs:
- ** loss** (*optional*, returned when `masked_lm_labels` is provided) `torch.FloatTensor` of shape `(1,)` --
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index cc50a956439c..b8aa71dacbe2 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -58,6 +58,8 @@
title: Control image brightness
- local: using-diffusers/weighted_prompts
title: Prompt weighting
+ - local: using-diffusers/freeu
+ title: Improve generation quality with FreeU
title: Techniques
- sections:
- local: using-diffusers/pipeline_overview
@@ -104,6 +106,8 @@
title: Custom Diffusion
- local: training/t2i_adapters
title: T2I-Adapters
+ - local: training/ddpo
+ title: Reinforcement learning training with DDPO
title: Training
- sections:
- local: using-diffusers/other-modalities
diff --git a/docs/source/en/api/pipelines/stable_diffusion/adapter.md b/docs/source/en/api/pipelines/stable_diffusion/adapter.md
index 4c7415ddb02b..cf3aca4bfa52 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/adapter.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/adapter.md
@@ -28,8 +28,8 @@ This model was contributed by the community contributor [HimariO](https://github
| Pipeline | Tasks | Demo
|---|---|:---:|
-| [StableDiffusionAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning* | -
-| [StableDiffusionXLAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_xl_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning on StableDiffusion-XL* | -
+| [StableDiffusionAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning* | -
+| [StableDiffusionXLAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning on StableDiffusion-XL* | -
## Usage example with the base model of StableDiffusion-1.4/1.5
diff --git a/docs/source/en/training/ddpo.md b/docs/source/en/training/ddpo.md
new file mode 100644
index 000000000000..1ec961dfdd04
--- /dev/null
+++ b/docs/source/en/training/ddpo.md
@@ -0,0 +1,17 @@
+
+
+# Reinforcement learning training with DDPO
+
+You can fine-tune Stable Diffusion on a reward function via reinforcement learning with the 🤗 TRL library and 🤗 Diffusers. This is done with the Denoising Diffusion Policy Optimization (DDPO) algorithm introduced by Black et al. in [Training Diffusion Models with Reinforcement Learning](https://arxiv.org/abs/2305.13301), which is implemented in 🤗 TRL with the [`~trl.DDPOTrainer`].
+
+For more information, check out the [`~trl.DDPOTrainer`] API reference and the [Finetune Stable Diffusion Models with DDPO via TRL](https://huggingface.co/blog/trl-ddpo) blog post.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/freeu.md b/docs/source/en/using-diffusers/freeu.md
new file mode 100644
index 000000000000..6c23ec754382
--- /dev/null
+++ b/docs/source/en/using-diffusers/freeu.md
@@ -0,0 +1,123 @@
+# Improve generation quality with FreeU
+
+[[open-in-colab]]
+
+The UNet is responsible for denoising during the reverse diffusion process, and there are two distinct features in its architecture:
+
+1. Backbone features primarily contribute to the denoising process
+2. Skip features mainly introduce high-frequency features into the decoder module and can make the network overlook the semantics in the backbone features
+
+However, the skip connection can sometimes introduce unnatural image details. [FreeU](https://hf.co/papers/2309.11497) is a technique for improving image quality by rebalancing the contributions from the UNet’s skip connections and backbone feature maps.
+
+FreeU is applied during inference and it does not require any additional training. The technique works for different tasks such as text-to-image, image-to-image, and text-to-video.
+
+In this guide, you will apply FreeU to the [`StableDiffusionPipeline`], [`StableDiffusionXLPipeline`], and [`TextToVideoSDPipeline`].
+
+## StableDiffusionPipeline
+
+Load the pipeline:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None
+).to("cuda")
+```
+
+Then enable the FreeU mechanism with the FreeU-specific hyperparameters. These values are scaling factors for the backbone and skip features.
+
+```py
+pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
+```
+
+The values above are from the official FreeU [code repository](https://github.com/ChenyangSi/FreeU) where you can also find [reference hyperparameters](https://github.com/ChenyangSi/FreeU#range-for-more-parameters) for different models.
+
+
+
+Disable the FreeU mechanism by calling `disable_freeu()` on a pipeline.
+
+
+
+And then run inference:
+
+```py
+prompt = "A squirrel eating a burger"
+seed = 2023
+image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
+```
+
+The figure below compares non-FreeU and FreeU results respectively for the same hyperparameters used above (`prompt` and `seed`):
+
+
+
+
+Let's see how Stable Diffusion 2 results are impacted:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16, safety_checker=None
+).to("cuda")
+
+prompt = "A squirrel eating a burger"
+seed = 2023
+
+pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
+image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
+```
+
+
+
+
+## Stable Diffusion XL
+
+Finally, let's take a look at how FreeU affects Stable Diffusion XL results:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16,
+).to("cuda")
+
+prompt = "A squirrel eating a burger"
+seed = 2023
+
+# Comes from
+# https://wandb.ai/nasirk24/UNET-FreeU-SDXL/reports/FreeU-SDXL-Optimal-Parameters--Vmlldzo1NDg4NTUw
+pipeline.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
+image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
+```
+
+
+
+
+## Text-to-video generation
+
+FreeU can also be used to improve video quality:
+
+```python
+from diffusers import DiffusionPipeline
+from diffusers.utils import export_to_video
+import torch
+
+model_id = "cerspense/zeroscope_v2_576w"
+pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16).to("cuda")
+pipe = pipe.to("cuda")
+
+prompt = "an astronaut riding a horse on mars"
+seed = 2023
+
+# The values come from
+# https://github.com/lyn-rgb/FreeU_Diffusers#video-pipelines
+pipe.enable_freeu(b1=1.2, b2=1.4, s1=0.9, s2=0.2)
+video_frames = pipe(prompt, height=320, width=576, num_frames=30, generator=torch.manual_seed(seed)).frames
+export_to_video(video_frames, "astronaut_rides_horse.mp4")
+```
+
+Thanks to [kadirnar](https://github.com/kadirnar/) for helping to integrate the feature, and to [justindujardin](https://github.com/justindujardin) for the helpful discussions.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/img2img.md b/docs/source/en/using-diffusers/img2img.md
index 82aa328d2b9c..c0bf4dc52153 100644
--- a/docs/source/en/using-diffusers/img2img.md
+++ b/docs/source/en/using-diffusers/img2img.md
@@ -33,7 +33,7 @@ pipeline.enable_xformers_memory_efficient_attention()
-You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`], to save memory and increase inference speed. If you're using PyTorch 2.0, then you don't need to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](/optimization/torch2.0#scaled-dot-product-attention).
+You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`], to save memory and increase inference speed. If you're using PyTorch 2.0, then you don't need to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention).
@@ -590,17 +590,17 @@ image
## Optimize
-Running diffusion models is computationally expensive and intensive, but with a few optimization tricks, it is entirely possible to run them on consumer and free-tier GPUs. For example, you can use a more memory-efficient form of attention such as PyTorch 2.0's [scaled-dot product attention](optimization/torch2.0#scaled-dot-product-attention) or [xFormers](optimization/xformers) (you can use one or the other, but there's no need to use both). You can also offload the model to the GPU while the other pipeline components wait on the CPU.
+Running diffusion models is computationally expensive and intensive, but with a few optimization tricks, it is entirely possible to run them on consumer and free-tier GPUs. For example, you can use a more memory-efficient form of attention such as PyTorch 2.0's [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention) or [xFormers](../optimization/xformers) (you can use one or the other, but there's no need to use both). You can also offload the model to the GPU while the other pipeline components wait on the CPU.
```diff
+ pipeline.enable_model_cpu_offload()
+ pipeline.enable_xformers_memory_efficient_attention()
```
-With [`torch.compile`](optimization/torch2.0#torch.compile), you can boost your inference speed even more by wrapping your UNet with it:
+With [`torch.compile`](../optimization/torch2.0#torch.compile), you can boost your inference speed even more by wrapping your UNet with it:
```py
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
```
-To learn more, take a look at the [Reduce memory usage](optimization/memory) and [Torch 2.0](optimization/torch2.0) guides.
+To learn more, take a look at the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
diff --git a/docs/source/en/using-diffusers/inpaint.md b/docs/source/en/using-diffusers/inpaint.md
index 7f10e43243a3..4d99fca26eb6 100644
--- a/docs/source/en/using-diffusers/inpaint.md
+++ b/docs/source/en/using-diffusers/inpaint.md
@@ -10,87 +10,289 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Text-guided image-inpainting
+# Inpainting
[[open-in-colab]]
-The [`StableDiffusionInpaintPipeline`] allows you to edit specific parts of an image by providing a mask and a text prompt. It uses a version of Stable Diffusion, like [`runwayml/stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting) specifically trained for inpainting tasks.
+Inpainting replaces or edits specific areas of an image. This makes it a useful tool for image restoration like removing defects and artifacts, or even replacing an image area with something entirely new. Inpainting relies on a mask to determine which regions of an image to fill in; the area to inpaint is represented by white pixels and the area to keep is represented by black pixels. The white pixels are filled in by the prompt.
-Get started by loading an instance of the [`StableDiffusionInpaintPipeline`]:
+With 🤗 Diffusers, here is how you can do inpainting:
-```python
-import PIL
-import requests
+1. Load an inpainting checkpoint with the [`AutoPipelineForInpainting`] class. This'll automatically detect the appropriate pipeline class to load based on the checkpoint:
+
+```py
import torch
-from io import BytesIO
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image
-from diffusers import StableDiffusionInpaintPipeline
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
+).to("cuda")
+pipeline.enable_model_cpu_offload()
+pipeline.enable_xformers_memory_efficient_attention()
+```
-pipeline = StableDiffusionInpaintPipeline.from_pretrained(
- "runwayml/stable-diffusion-inpainting",
- torch_dtype=torch.float16,
- use_safetensors=True,
- variant="fp16",
-)
-pipeline = pipeline.to("cuda")
+
+
+You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`], to save memory and increase inference speed. If you're using PyTorch 2.0, it's not necessary to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention).
+
+
+
+2. Load the base and mask images:
+
+```py
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
```
-Download an image and a mask of a dog which you'll eventually replace:
+3. Create a prompt to inpaint the image with and pass it to the pipeline with the base and mask images:
-```python
-def download_image(url):
- response = requests.get(url)
- return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+```py
+prompt = "a black cat with glowing eyes, cute, adorable, disney, pixar, highly detailed, 8k"
+negative_prompt = "bad anatomy, deformed, ugly, disfigured"
+image = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask_image).images[0]
+```
+
+
+
+ base image
+
+
+
+ generated image
+
+
-img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
-mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+## Popular models
+
+[Stable Diffusion Inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting), [Stable Diffusion XL (SDXL) Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1), and [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder-inpaint) are among the most popular models for inpainting. SDXL typically produces higher resolution images than Stable Diffusion v1.5, and Kandinsky 2.2 is also capable of generating high-quality images.
+
+### Stable Diffusion Inpainting
+
+Stable Diffusion Inpainting is a latent diffusion model finetuned on 512x512 images on inpainting. It is a good starting point because it is relatively fast and generates good quality images. To use this model for inpainting, you'll need to pass a prompt, base and mask image to the pipeline:
-init_image = download_image(img_url).resize((512, 512))
-mask_image = download_image(mask_url).resize((512, 512))
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+pipeline.enable_model_cpu_offload()
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+
+generator = torch.Generator("cuda").manual_seed(92)
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
```
-Now you can create a prompt to replace the mask with something else:
+### Stable Diffusion XL (SDXL) Inpainting
-```python
-prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
-image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+SDXL is a larger and more powerful version of Stable Diffusion v1.5. This model can follow a two-stage model process (though each model can also be used alone); the base model generates an image, and a refiner model takes that image and further enhances its details and quality. Take a look at the [SDXL](sdxl) guide for a more comprehensive guide on how to use SDXL and configure it's parameters.
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+pipeline.enable_model_cpu_offload()
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+
+generator = torch.Generator("cuda").manual_seed(92)
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
```
-`image` | `mask_image` | `prompt` | output |
-:-------------------------:|:-------------------------:|:-------------------------:|-------------------------:|
- | | ***Face of a yellow cat, high resolution, sitting on a park bench*** | |
+### Kandinsky 2.2 Inpainting
+The Kandinsky model family is similar to SDXL because it uses two models as well; the image prior model creates image embeddings, and the diffusion model generates images from them. You can load the image prior and diffusion model separately, but the easiest way to use Kandinsky 2.2 is to load it into the [`AutoPipelineForInpainting`] class which uses the [`KandinskyV22InpaintCombinedPipeline`] under the hood.
-
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image
-A previous experimental implementation of inpainting used a different, lower-quality process. To ensure backwards compatibility, loading a pretrained pipeline that doesn't contain the new model will still apply the old inpainting method.
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
+).to("cuda")
+pipeline.enable_model_cpu_offload()
+pipeline.enable_xformers_memory_efficient_attention()
-
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+
+generator = torch.Generator("cuda").manual_seed(92)
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
+```
-Check out the Spaces below to try out image inpainting yourself!
+
+
+
+ base image
+
+
+
+ Stable Diffusion Inpainting
+
+
+
+ Stable Diffusion XL Inpainting
+
+
+
+ Kandinsky 2.2 Inpainting
+
+
+
+## Configure pipeline parameters
+
+Image features - like quality and "creativity" - are dependent on pipeline parameters. Knowing what these parameters do is important for getting the results you want. Let's take a look at the most important parameters and see how changing them affects the output.
+
+## Strength
+
+`strength` is a measure of how much noise is added to the base image, which influences how similar the output is to the base image.
+
+* 📈 a high `strength` value means more noise is added to an image and the denoising process takes longer, but you'll get higher quality images that are more different from the base image
+* 📉 a low `strength` value means less noise is added to an image and the denoising process is faster, but the image quality may not be as great and the generated image resembles the base image more
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+pipeline.enable_model_cpu_offload()
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.6).images[0]
+```
+
+
+
+
+ strength = 0.6
+
+
+
+ strength = 0.8
+
+
+
+ strength = 1.0
+
+
+
+## Guidance scale
+
+`guidance_scale` affects how aligned the text prompt and generated image are.
+
+* 📈 a high `guidance_scale` value means the prompt and generated image are closely aligned, so the output is a stricter interpretation of the prompt
+* 📉 a low `guidance_scale` value means the prompt and generated image are more loosely aligned, so the output may be more varied from the prompt
+
+You can use `strength` and `guidance_scale` together for more control over how expressive the model is. For example, a combination high `strength` and `guidance_scale` values gives the model the most creative freedom.
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+pipeline.enable_model_cpu_offload()
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, guidance_scale=2.5).images[0]
+```
-
+
+
+
+ guidance_scale = 2.5
+
+
+
+ guidance_scale = 7.5
+
+
+
+ guidance_scale = 12.5
+
+
+
+### Negative prompt
+
+A negative prompt assumes the opposite role of a prompt; it guides the model away from generating certain things in an image. This is useful for quickly improving image quality and preventing the model from generating things you don't want.
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+pipeline.enable_model_cpu_offload()
+pipeline.enable_xformers_memory_efficient_attention()
-## Preserving the Unmasked Area of the Image
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+negative_prompt = "bad architecture, unstable, poor details, blurry"
+image = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask_image).images[0]
+image
+```
-Generally speaking, [`StableDiffusionInpaintPipeline`] (and other inpainting pipelines) will change the unmasked part of the image as well. If this behavior is undesirable, you can force the unmasked area to remain the same as follows:
+
-```python
+## Preserve unmasked areas
+
+The [`AutoPipelineForInpainting`] (and other inpainting pipelines) generally changes the unmasked parts of an image to create a more natural transition between the masked and unmasked region. If this behavior is undesirable, you can force the unmasked area to remain the same. However, forcing the unmasked portion of the image to remain the same may result in some unusual transitions between the unmasked and masked areas.
+
+```py
import PIL
import numpy as np
import torch
-from diffusers import StableDiffusionInpaintPipeline
+from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image
device = "cuda"
-pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+pipeline = AutoPipelineForInpainting.from_pretrained(
"runwayml/stable-diffusion-inpainting",
torch_dtype=torch.float16,
)
@@ -121,4 +323,257 @@ unmasked_unchanged_image = PIL.Image.fromarray(unmasked_unchanged_image_arr.roun
unmasked_unchanged_image.save("force_unmasked_unchanged.png")
```
-Forcing the unmasked portion of the image to remain the same might result in some weird transitions between the unmasked and masked areas, since the model will typically change the masked and unmasked areas to make the transition more natural.
+## Chained inpainting pipelines
+
+[`AutoPipelineForInpainting`] can be chained with other 🤗 Diffusers pipelines to edit their outputs. This is often useful for improving the output quality from your other diffusion pipelines, and if you're using multiple pipelines, it can be more memory-efficient to chain them together to keep the outputs in latent space and reuse the same pipeline components.
+
+### Text-to-image-to-inpaint
+
+Chaining a text-to-image and inpainting pipeline allows you to inpaint the generated image, and you don't have to provide a base image to begin with. This makes it convenient to edit your favorite text-to-image outputs without having to generate an entirely new image.
+
+Start with the text-to-image pipeline to create a castle:
+
+```py
+import torch
+from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting
+from diffusers.utils import load_image
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
+pipeline.enable_model_cpu_offload()
+pipeline.enable_xformers_memory_efficient_attention()
+
+image = pipeline("concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k").images[0]
+```
+
+Load the mask image of the output from above:
+
+```py
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_text-chain-mask.png").convert("RGB")
+```
+
+And let's inpaint the masked area with a waterfall:
+
+```py
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+pipeline.enable_model_cpu_offload()
+pipeline.enable_xformers_memory_efficient_attention()
+
+prompt = "digital painting of a fantasy waterfall, cloudy"
+image = pipeline(prompt=prompt, image=image, mask_image=mask_image).images[0]
+image
+```
+
+
+
+
+ text-to-image
+
+
+
+ inpaint
+
+
+
+
+### Inpaint-to-image-to-image
+
+You can also chain an inpainting pipeline before another pipeline like image-to-image or an upscaler to improve the quality.
+
+Begin by inpainting an image:
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting, AutoPipelineForImage2Image
+from diffusers.utils import load_image
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+pipeline.enable_model_cpu_offload()
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+
+# resize image to 1024x1024 for SDXL
+image = image.resize((1024, 1024))
+```
+
+Now let's pass the image to another inpainting pipeline with SDXL's refiner model to enhance the image details and quality:
+
+```py
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+pipeline.enable_model_cpu_offload()
+pipeline.enable_xformers_memory_efficient_attention()
+
+image = pipeline(prompt=prompt, image=image, mask_image=mask_image, output_type="latent").images[0]
+```
+
+
+
+It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in latent space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE. For example, in the [Text-to-image-to-inpaint](#text-to-image-to-inpaint) section, Kandinsky 2.2 uses a different VAE class than the Stable Diffusion model so it won't work. But if you use Stable Diffusion v1.5 for both pipelines, then you can keep everything in latent space because they both use [`AutoencoderKL`].
+
+
+
+Finally, you can pass this image to an image-to-image pipeline to put the finishing touches on it. It is more efficient to use the [`~AutoPipelineForImage2Image.from_pipe`] method to reuse the existing pipeline components, and avoid unnecessarily loading all the pipeline components into memory again.
+
+```py
+pipeline = AutoPipelineForImage2Image.from_pipe(pipeline)
+pipeline.enable_xformers_memory_efficient_attention()
+
+image = pipeline(prompt=prompt, image=image).images[0]
+```
+
+
+
+
+ initial image
+
+
+
+ inpaint
+
+
+
+ image-to-image
+
+
+
+Image-to-image and inpainting are actually very similar tasks. Image-to-image generates a new image that resembles the existing provided image. Inpainting does the same thing, but it only transforms the image area defined by the mask and the rest of the image is unchanged. You can think of inpainting as a more precise tool for making specific changes and image-to-image has a broader scope for making more sweeping changes.
+
+## Control image generation
+
+Getting an image to look exactly the way you want is challenging because the denoising process is random. While you can control certain aspects of generation by configuring parameters like `negative_prompt`, there are better and more efficient methods for controlling image generation.
+
+### Prompt weighting
+
+Prompt weighting provides a quantifiable way to scale the representation of concepts in a prompt. You can use it to increase or decrease the magnitude of the text embedding vector for each concept in the prompt, which subsequently determines how much of each concept is generated. The [Compel](https://github.com/damian0815/compel) library offers an intuitive syntax for scaling the prompt weights and generating the embeddings. Learn how to create the embeddings in the [Prompt weighting](../using-diffusers/weighted_prompts) guide.
+
+Once you've generated the embeddings, pass them to the `prompt_embeds` (and `negative_prompt_embeds` if you're using a negative prompt) parameter in the [`AutoPipelineForInpainting`]. The embeddings replace the `prompt` parameter:
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16,
+).to("cuda")
+pipeline.enable_model_cpu_offload()
+pipeline.enable_xformers_memory_efficient_attention()
+
+image = pipeline(prompt_emebds=prompt_embeds, # generated from Compel
+ negative_prompt_embeds, # generated from Compel
+ image=init_image,
+ mask_image=mask_image
+).images[0]
+```
+
+### ControlNet
+
+ControlNet models are used with other diffusion models like Stable Diffusion, and they provide an even more flexible and accurate way to control how an image is generated. A ControlNet accepts an additional conditioning image input that guides the diffusion model to preserve the features in it.
+
+For example, let's condition an image with a ControlNet pretrained on inpaint images:
+
+```py
+import torch
+import numpy as np
+from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
+from diffusers.utils import load_image
+
+# load ControlNet
+controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16, variant="fp16")
+
+# pass ControlNet to the pipeline
+pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+pipeline.enable_model_cpu_offload()
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+
+# prepare control image
+def make_inpaint_condition(init_image, mask_image):
+ init_image = np.array(init_image.convert("RGB")).astype(np.float32) / 255.0
+ mask_image = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
+
+ assert init_image.shape[0:1] == mask_image.shape[0:1], "image and image_mask must have the same image size"
+ init_image[mask_image > 0.5] = -1.0 # set as masked pixel
+ init_image = np.expand_dims(init_image, 0).transpose(0, 3, 1, 2)
+ init_image = torch.from_numpy(init_image)
+ return init_image
+
+control_image = make_inpaint_condition(init_image, mask_image)
+```
+
+Now generate an image from the base, mask and control images. You'll notice features of the base image are strongly preserved in the generated image.
+
+```py
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, control_image=control_image).images[0]
+image
+```
+
+You can take this a step further and chain it with an image-to-image pipeline to apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion):
+
+```py
+from diffusers import AutoPipelineForImage2Image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "nitrosocke/elden-ring-diffusion", torch_dtype=torch.float16,
+).to("cuda")
+pipeline.enable_model_cpu_offload()
+pipeline.enable_xformers_memory_efficient_attention()
+
+prompt = "elden ring style castle" # include the token "elden ring style" in the prompt
+negative_prompt = "bad architecture, deformed, disfigured, poor details"
+
+image = pipeline(prompt, negative_prompt=negative_prompt, image=image).images[0]
+image
+```
+
+
+
+
+ initial image
+
+
+
+ ControlNet inpaint
+
+
+
+ image-to-image
+
+
+
+## Optimize
+
+It can be difficult and slow to run diffusion models if you're resource constrained, but it dosen't have to be with a few optimization tricks. One of the biggest (and easiest) optimizations you can enable is switching to memory-efficient attention. If you're using PyTorch 2.0, [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention) is automatically enabled and you don't need to do anything else. For non-PyTorch 2.0 users, you can install and use [xFormers](../optimization/xformers)'s implementation of memory-efficient attention. Both options reduce memory usage and accelerate inference.
+
+You can also offload the model to the GPU to save even more memory:
+
+```diff
++ pipeline.enable_xformers_memory_efficient_attention()
++ pipeline.enable_model_cpu_offload()
+```
+
+To speed-up your inference code even more, use [`torch_compile`](../optimization/torch2.0#torch.compile). You should wrap `torch.compile` around the most intensive component in the pipeline which is typically the UNet:
+
+```py
+pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+```
+
+Learn more in the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/write_own_pipeline.md b/docs/source/en/using-diffusers/write_own_pipeline.md
index 42b3e4d6761d..a9243a7b9adc 100644
--- a/docs/source/en/using-diffusers/write_own_pipeline.md
+++ b/docs/source/en/using-diffusers/write_own_pipeline.md
@@ -112,7 +112,7 @@ As you can see, this is already more complex than the DDPM pipeline which only c
-💡 Read the [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) blog for more details about how the VAE, UNet, and text encoder models.
+💡 Read the [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) blog for more details about how the VAE, UNet, and text encoder models work.
@@ -214,7 +214,7 @@ Next, generate some initial random noise as a starting point for the diffusion p
```py
>>> latents = torch.randn(
-... (batch_size, unet.in_channels, height // 8, width // 8),
+... (batch_size, unet.config.in_channels, height // 8, width // 8),
... generator=generator,
... )
>>> latents = latents.to(torch_device)
diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml
index 895273d851f3..41d5e95a4230 100644
--- a/docs/source/zh/_toctree.yml
+++ b/docs/source/zh/_toctree.yml
@@ -3,6 +3,8 @@
title: 🧨 Diffusers
- local: quicktour
title: 快速入门
+ - local: stable_diffusion
+ title: 有效和高效的扩散
- local: installation
title: 安装
title: 开始
diff --git a/docs/source/zh/stable_diffusion.md b/docs/source/zh/stable_diffusion.md
new file mode 100644
index 000000000000..e28607b09032
--- /dev/null
+++ b/docs/source/zh/stable_diffusion.md
@@ -0,0 +1,264 @@
+
+
+# 有效且高效的扩散
+
+[[open-in-colab]]
+
+让 [`DiffusionPipeline`] 生成特定风格或包含你所想要的内容的图像可能会有些棘手。 通常情况下,你需要多次运行 [`DiffusionPipeline`] 才能得到满意的图像。但是从无到有生成图像是一个计算密集的过程,特别是如果你要一遍又一遍地进行推理运算。
+
+这就是为什么从pipeline中获得最高的 *computational* (speed) 和 *memory* (GPU RAM) 非常重要 ,以减少推理周期之间的时间,从而使迭代速度更快。
+
+
+本教程将指导您如何通过 [`DiffusionPipeline`] 更快、更好地生成图像。
+
+
+首先,加载 [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) 模型:
+
+```python
+from diffusers import DiffusionPipeline
+
+model_id = "runwayml/stable-diffusion-v1-5"
+pipeline = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True)
+```
+
+本教程将使用的提示词是 [`portrait photo of a old warrior chief`] ,但是你可以随心所欲的想象和构造自己的提示词:
+
+```python
+prompt = "portrait photo of a old warrior chief"
+```
+
+## 速度
+
+
+
+💡 如果你没有 GPU, 你可以从像 [Colab](https://colab.research.google.com/) 这样的 GPU 提供商获取免费的 GPU !
+
+
+
+加速推理的最简单方法之一是将 pipeline 放在 GPU 上 ,就像使用任何 PyTorch 模块一样:
+
+```python
+pipeline = pipeline.to("cuda")
+```
+
+为了确保您可以使用相同的图像并对其进行改进,使用 [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) 方法,然后设置一个随机数种子 以确保其 [复现性](./using-diffusers/reproducibility):
+
+```python
+import torch
+
+generator = torch.Generator("cuda").manual_seed(0)
+```
+
+现在,你可以生成一个图像:
+
+```python
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+
+非常的令人印象深刻! Let's tweak the second image - 把 `Generator` 的种子设置为 `1` - 添加一些关于年龄的主题文本:
+
+```python
+prompts = [
+ "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+]
+
+generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
+images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
+make_image_grid(images, 2, 2)
+```
+
+
+
+
+
+## 最后
+
+在本教程中, 您学习了如何优化[`DiffusionPipeline`]以提高计算和内存效率,以及提高生成输出的质量. 如果你有兴趣让你的 pipeline 更快, 可以看一看以下资源:
+
+- 学习 [PyTorch 2.0](./optimization/torch2.0) 和 [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) 可以让推理速度提高 5 - 300% . 在 A100 GPU 上, 推理速度可以提高 50% !
+- 如果你没法用 PyTorch 2, 我们建议你安装 [xFormers](./optimization/xformers)。它的内存高效注意力机制(*memory-efficient attention mechanism*)与PyTorch 1.13.1配合使用,速度更快,内存消耗更少。
+- 其他的优化技术, 如:模型卸载(*model offloading*), 包含在 [这份指南](./optimization/fp16).
diff --git a/examples/community/composable_stable_diffusion.py b/examples/community/composable_stable_diffusion.py
index 8a2263b096c3..996bb3cef8bf 100644
--- a/examples/community/composable_stable_diffusion.py
+++ b/examples/community/composable_stable_diffusion.py
@@ -562,7 +562,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# 8. Post-processing
image = self.decode_latents(latents)
diff --git a/examples/community/img2img_inpainting.py b/examples/community/img2img_inpainting.py
index 7ddb3fe89464..8ee8355d49a6 100644
--- a/examples/community/img2img_inpainting.py
+++ b/examples/community/img2img_inpainting.py
@@ -434,7 +434,8 @@ def __call__(
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
diff --git a/examples/community/interpolate_stable_diffusion.py b/examples/community/interpolate_stable_diffusion.py
index 8f33db71b9f3..70e4d025a037 100644
--- a/examples/community/interpolate_stable_diffusion.py
+++ b/examples/community/interpolate_stable_diffusion.py
@@ -372,7 +372,8 @@ def __call__(
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py
index a735ed040b76..ee0cdc461cf5 100644
--- a/examples/community/lpw_stable_diffusion.py
+++ b/examples/community/lpw_stable_diffusion.py
@@ -1088,7 +1088,8 @@ def __call__(
progress_bar.update()
if i % callback_steps == 0:
if callback is not None:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if is_cancelled_callback is not None and is_cancelled_callback():
return None
diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py
index cdebb81c6d5e..423e6ced4d77 100644
--- a/examples/community/lpw_stable_diffusion_onnx.py
+++ b/examples/community/lpw_stable_diffusion_onnx.py
@@ -846,7 +846,8 @@ def __call__(
# call the callback, if provided
if i % callback_steps == 0:
if callback is not None:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if is_cancelled_callback is not None and is_cancelled_callback():
return None
diff --git a/examples/community/lpw_stable_diffusion_xl.py b/examples/community/lpw_stable_diffusion_xl.py
index aaf1e20a9d7c..66e2ffb159a1 100644
--- a/examples/community/lpw_stable_diffusion_xl.py
+++ b/examples/community/lpw_stable_diffusion_xl.py
@@ -1182,7 +1182,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
diff --git a/examples/community/masked_stable_diffusion_img2img.py b/examples/community/masked_stable_diffusion_img2img.py
index a35b74da426a..0b08086c7da9 100644
--- a/examples/community/masked_stable_diffusion_img2img.py
+++ b/examples/community/masked_stable_diffusion_img2img.py
@@ -202,7 +202,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
scaled = latents / self.vae.config.scaling_factor
diff --git a/examples/community/multilingual_stable_diffusion.py b/examples/community/multilingual_stable_diffusion.py
index ff6c7e68f783..7597efd215af 100644
--- a/examples/community/multilingual_stable_diffusion.py
+++ b/examples/community/multilingual_stable_diffusion.py
@@ -407,7 +407,8 @@ def __call__(
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
diff --git a/examples/community/pipeline_prompt2prompt.py b/examples/community/pipeline_prompt2prompt.py
index 83e7c7d77c9e..7d330c668da9 100644
--- a/examples/community/pipeline_prompt2prompt.py
+++ b/examples/community/pipeline_prompt2prompt.py
@@ -254,7 +254,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# 8. Post-processing
if not output_type == "latent":
diff --git a/examples/community/pipeline_zero1to3.py b/examples/community/pipeline_zero1to3.py
index 2c0eef92c282..3e4e88ea5aa1 100644
--- a/examples/community/pipeline_zero1to3.py
+++ b/examples/community/pipeline_zero1to3.py
@@ -865,7 +865,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# 8. Post-processing
has_nsfw_concept = None
diff --git a/examples/community/run_onnx_controlnet.py b/examples/community/run_onnx_controlnet.py
index aab6f3873ce3..2b1123a4955c 100644
--- a/examples/community/run_onnx_controlnet.py
+++ b/examples/community/run_onnx_controlnet.py
@@ -815,7 +815,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
_latents = latents.cpu().detach().numpy() / 0.18215
diff --git a/examples/community/run_tensorrt_controlnet.py b/examples/community/run_tensorrt_controlnet.py
index 484fc043ed62..724f393eb122 100644
--- a/examples/community/run_tensorrt_controlnet.py
+++ b/examples/community/run_tensorrt_controlnet.py
@@ -919,7 +919,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
_latents = latents.cpu().detach().numpy() / 0.18215
diff --git a/examples/community/seed_resize_stable_diffusion.py b/examples/community/seed_resize_stable_diffusion.py
index 5891b9fb11a8..9318277b8f01 100644
--- a/examples/community/seed_resize_stable_diffusion.py
+++ b/examples/community/seed_resize_stable_diffusion.py
@@ -337,7 +337,8 @@ def __call__(
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
diff --git a/examples/community/speech_to_image_diffusion.py b/examples/community/speech_to_image_diffusion.py
index 55d805bc8c32..63bcfb662517 100644
--- a/examples/community/speech_to_image_diffusion.py
+++ b/examples/community/speech_to_image_diffusion.py
@@ -242,7 +242,8 @@ def __call__(
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
diff --git a/examples/community/stable_diffusion_controlnet_img2img.py b/examples/community/stable_diffusion_controlnet_img2img.py
index 71009fb1aa69..550aa8ba61a3 100644
--- a/examples/community/stable_diffusion_controlnet_img2img.py
+++ b/examples/community/stable_diffusion_controlnet_img2img.py
@@ -951,7 +951,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py
index 3cd9f9f0a258..30903bbf66bf 100644
--- a/examples/community/stable_diffusion_controlnet_inpaint.py
+++ b/examples/community/stable_diffusion_controlnet_inpaint.py
@@ -1100,7 +1100,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
diff --git a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
index 341e89398f7d..96ad3c39239d 100644
--- a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
+++ b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
@@ -1081,7 +1081,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
diff --git a/examples/community/stable_diffusion_controlnet_reference.py b/examples/community/stable_diffusion_controlnet_reference.py
index 0814c6b22af9..d786036bd58a 100644
--- a/examples/community/stable_diffusion_controlnet_reference.py
+++ b/examples/community/stable_diffusion_controlnet_reference.py
@@ -802,7 +802,8 @@ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
diff --git a/examples/community/stable_diffusion_ipex.py b/examples/community/stable_diffusion_ipex.py
index bef575559e07..2f8131d6cbc0 100644
--- a/examples/community/stable_diffusion_ipex.py
+++ b/examples/community/stable_diffusion_ipex.py
@@ -817,7 +817,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if output_type == "latent":
image = latents
diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py
index d2b5acda2340..505470574a0b 100644
--- a/examples/community/stable_diffusion_reference.py
+++ b/examples/community/stable_diffusion_reference.py
@@ -770,7 +770,8 @@ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/examples/community/stable_diffusion_repaint.py b/examples/community/stable_diffusion_repaint.py
index 07b3b7811017..ce4f245b31fa 100644
--- a/examples/community/stable_diffusion_repaint.py
+++ b/examples/community/stable_diffusion_repaint.py
@@ -932,7 +932,8 @@ def __call__(
# call the callback, if provided
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
t_last = t
diff --git a/examples/community/stable_diffusion_xl_reference.py b/examples/community/stable_diffusion_xl_reference.py
index a7654f11bcc9..5d2b1c771128 100644
--- a/examples/community/stable_diffusion_xl_reference.py
+++ b/examples/community/stable_diffusion_xl_reference.py
@@ -771,7 +771,8 @@ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
diff --git a/examples/community/wildcard_stable_diffusion.py b/examples/community/wildcard_stable_diffusion.py
index aec79fb8e12e..1a5ea350b857 100644
--- a/examples/community/wildcard_stable_diffusion.py
+++ b/examples/community/wildcard_stable_diffusion.py
@@ -389,7 +389,8 @@ def __call__(
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index 34e8c69ff64b..68162d7824ab 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -907,17 +907,10 @@ def compute_loss(params, minibatch, sample_rng):
if args.snr_gamma is not None:
snr = jnp.array(compute_snr(timesteps))
- base_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
if noise_scheduler.config.prediction_type == "v_prediction":
- snr_loss_weights = base_weights + 1
- else:
- # Epsilon and sample prediction use the base weights.
- snr_loss_weights = base_weights
- # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
- # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
- # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
- snr_loss_weights[snr == 0] = 1.0
-
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
loss = loss * snr_loss_weights
loss = loss.mean()
diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index 60d8d6723dcf..4773446a615b 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -207,7 +207,7 @@ def __init__(
with open(concept["class_prompt"], "r") as f:
class_prompt = f.read().splitlines()
- class_img_path = [(x, y) for (x, y) in zip(class_images_path, class_prompt)]
+ class_img_path = list(zip(class_images_path, class_prompt))
self.class_images_path.extend(class_img_path[:num_class_images])
random.shuffle(self.instance_images_path)
@@ -1075,30 +1075,30 @@ def main(args):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
+ initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
- resume_global_step = global_step * args.gradient_accumulation_steps
+ initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
-
- # Only show the progress bar once on each machine.
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
- progress_bar.set_description("Steps")
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
if args.modifier_token is not None:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
- # Skip steps until we reach the resumed step
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
- if step % args.gradient_accumulation_steps == 0:
- progress_bar.update(1)
- continue
-
with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -1214,50 +1214,52 @@ def main(args):
if global_step >= args.max_train_steps:
break
- if accelerator.is_main_process:
- images = []
+ if accelerator.is_main_process:
+ images = []
- if args.validation_prompt is not None and global_step % args.validation_steps == 0:
- logger.info(
- f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
- f" {args.validation_prompt}."
- )
- # create pipeline
- pipeline = DiffusionPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- unet=accelerator.unwrap_model(unet),
- text_encoder=accelerator.unwrap_model(text_encoder),
- tokenizer=tokenizer,
- revision=args.revision,
- torch_dtype=weight_dtype,
- )
- pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
- pipeline = pipeline.to(accelerator.device)
- pipeline.set_progress_bar_config(disable=True)
-
- # run inference
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
- images = [
- pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0]
- for _ in range(args.num_validation_images)
- ]
-
- for tracker in accelerator.trackers:
- if tracker.name == "tensorboard":
- np_images = np.stack([np.asarray(img) for img in images])
- tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
- if tracker.name == "wandb":
- tracker.log(
- {
- "validation": [
- wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
- for i, image in enumerate(images)
- ]
- }
- )
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=accelerator.unwrap_model(unet),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[
+ 0
+ ]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
- del pipeline
- torch.cuda.empty_cache()
+ del pipeline
+ torch.cuda.empty_cache()
# Save the custom diffusion layers
accelerator.wait_for_everyone()
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 71510b18c8a3..606cc5c6cfdd 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -1178,30 +1178,30 @@ def compute_text_embeddings(prompt):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
+ initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
- resume_global_step = global_step * args.gradient_accumulation_steps
+ initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
-
- # Only show the progress bar once on each machine.
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
- progress_bar.set_description("Steps")
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
if args.train_text_encoder:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
- # Skip steps until we reach the resumed step
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
- if step % args.gradient_accumulation_steps == 0:
- progress_bar.update(1)
- continue
-
with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index dc90d10f2b26..47de88f338d1 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -854,7 +854,7 @@ def main(args):
# For Stable Diffusion, it should be equal to:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
- # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
+ # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
# => 32 layers
# Set correct lora layers
@@ -1108,30 +1108,30 @@ def compute_text_embeddings(prompt):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
+ initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
- resume_global_step = global_step * args.gradient_accumulation_steps
+ initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
-
- # Only show the progress bar once on each machine.
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
- progress_bar.set_description("Steps")
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
if args.train_text_encoder:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
- # Skip steps until we reach the resumed step
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
- if step % args.gradient_accumulation_steps == 0:
- progress_bar.update(1)
- continue
-
with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 24dbf4313662..ac59bba6c847 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1048,18 +1048,25 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
+ initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
- resume_global_step = global_step * args.gradient_accumulation_steps
+ initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
- # Only show the progress bar once on each machine.
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
- progress_bar.set_description("Steps")
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
@@ -1067,12 +1074,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
text_encoder_one.train()
text_encoder_two.train()
for step, batch in enumerate(train_dataloader):
- # Skip steps until we reach the resumed step
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
- if step % args.gradient_accumulation_steps == 0:
- progress_bar.update(1)
- continue
-
with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
index 4d0b9bef55f1..e2d9b2105160 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -726,6 +726,9 @@ def preprocess_images(examples):
text_encoder_1.requires_grad_(False)
text_encoder_2.requires_grad_(False)
+ # Set UNet to trainable.
+ unet.train()
+
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(text_encoders, tokenizers, prompt):
prompt_embeds_list = []
@@ -933,29 +936,28 @@ def collate_fn(examples):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
+ initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
- resume_global_step = global_step * args.gradient_accumulation_steps
+ initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
-
- # Only show the progress bar once on each machine.
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
- progress_bar.set_description("Steps")
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
for epoch in range(first_epoch, args.num_train_epochs):
- unet.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
- # Skip steps until we reach the resumed step
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
- if step % args.gradient_accumulation_steps == 0:
- progress_bar.update(1)
- continue
-
with accelerator.accumulate(unet):
# We want to learn the denoising process w.r.t the edited images which
# are conditioned on the original image (which was edited) and the edit instruction.
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
index dd79c88f8a76..4ca95ecebea9 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
@@ -512,6 +512,9 @@ def deepspeed_zero_init_disabled_context_manager():
vae.requires_grad_(False)
image_encoder.requires_grad_(False)
+ # Set unet to trainable.
+ unet.train()
+
# Create EMA for the unet.
if args.use_ema:
ema_unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
@@ -727,27 +730,28 @@ def collate_fn(examples):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
+ initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
- resume_global_step = global_step * args.gradient_accumulation_steps
+ initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
- progress_bar.set_description("Steps")
for epoch in range(first_epoch, args.num_train_epochs):
- unet.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
- # Skip steps until we reach the resumed step
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
- if step % args.gradient_accumulation_steps == 0:
- progress_bar.update(1)
- continue
-
with accelerator.accumulate(unet):
# Convert images to latent space
images = batch["pixel_values"].to(weight_dtype)
@@ -777,25 +781,13 @@ def collate_fn(examples):
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
- base_weight = (
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
- if noise_scheduler.config.prediction_type == "v_prediction":
- # Velocity objective needs to be floored to an SNR weight of one.
- mse_loss_weights = base_weight + 1
- else:
- # Epsilon and sample both use the same loss weights.
- mse_loss_weights = base_weight
-
- # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
- # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
- # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
- mse_loss_weights[snr == 0] = 1.0
-
- # We first calculate the original loss. Then we mean over the non-batch dimensions and
- # rebalance the sample-wise losses with their respective loss weights.
- # Finally, we take the mean of the rebalanced loss.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
index 5e5f4b9cbf5d..19245724ecf5 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
@@ -579,29 +579,29 @@ def collate_fn(examples):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
+ initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
- resume_global_step = global_step * args.gradient_accumulation_steps
+ initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+ else:
+ initial_global_step = 0
- # Only show the progress bar once on each machine.
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
- progress_bar.set_description("Steps")
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
- # Skip steps until we reach the resumed step
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
- if step % args.gradient_accumulation_steps == 0:
- progress_bar.update(1)
- continue
-
with accelerator.accumulate(unet):
# Convert images to latent space
images = batch["pixel_values"].to(weight_dtype)
@@ -631,25 +631,13 @@ def collate_fn(examples):
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
- base_weight = (
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
- if noise_scheduler.config.prediction_type == "v_prediction":
- # Velocity objective needs to be floored to an SNR weight of one.
- mse_loss_weights = base_weight + 1
- else:
- # Epsilon and sample both use the same loss weights.
- mse_loss_weights = base_weight
-
- # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
- # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
- # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
- mse_loss_weights[snr == 0] = 1.0
-
- # We first calculate the original loss. Then we mean over the non-batch dimensions and
- # rebalance the sample-wise losses with their respective loss weights.
- # Finally, we take the mean of the rebalanced loss.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
index d2aabb948969..7305137218ef 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
@@ -595,30 +595,33 @@ def collate_fn(examples):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
+ initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
- resume_global_step = global_step * args.gradient_accumulation_steps
+ initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
- # Only show the progress bar once on each machine.
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
- progress_bar.set_description("Steps")
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
clip_std = clip_std.to(weight_dtype).to(accelerator.device)
+
for epoch in range(first_epoch, args.num_train_epochs):
prior.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
- # Skip steps until we reach the resumed step
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
- if step % args.gradient_accumulation_steps == 0:
- progress_bar.update(1)
- continue
-
with accelerator.accumulate(prior):
# Convert images to latent space
text_input_ids, text_mask, clip_images = (
@@ -661,25 +664,13 @@ def collate_fn(examples):
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
- base_weight = (
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
- if noise_scheduler.config.prediction_type == "v_prediction":
- # Velocity objective needs to be floored to an SNR weight of one.
- mse_loss_weights = base_weight + 1
- else:
- # Epsilon and sample both use the same loss weights.
- mse_loss_weights = base_weight
-
- # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
- # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
- # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
- mse_loss_weights[snr == 0] = 1.0
-
- # We first calculate the original loss. Then we mean over the non-batch dimensions and
- # rebalance the sample-wise losses with their respective loss weights.
- # Finally, we take the mean of the rebalanced loss.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
index b86df4de600c..d21eaf3dd0b0 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
@@ -517,6 +517,9 @@ def deepspeed_zero_init_disabled_context_manager():
text_encoder.requires_grad_(False)
image_encoder.requires_grad_(False)
+ # Set prior to trainable.
+ prior.train()
+
# Create EMA for the prior.
if args.use_ema:
ema_prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
@@ -741,32 +744,31 @@ def collate_fn(examples):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
+ initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
- resume_global_step = global_step * args.gradient_accumulation_steps
+ initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+ else:
+ initial_global_step = 0
- # Only show the progress bar once on each machine.
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
- progress_bar.set_description("Steps")
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
clip_std = clip_std.to(weight_dtype).to(accelerator.device)
for epoch in range(first_epoch, args.num_train_epochs):
- prior.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
- # Skip steps until we reach the resumed step
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
- if step % args.gradient_accumulation_steps == 0:
- progress_bar.update(1)
- continue
-
with accelerator.accumulate(prior):
# Convert images to latent space
text_input_ids, text_mask, clip_images = (
@@ -809,25 +811,13 @@ def collate_fn(examples):
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
- base_weight = (
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
- if noise_scheduler.config.prediction_type == "v_prediction":
- # Velocity objective needs to be floored to an SNR weight of one.
- mse_loss_weights = base_weight + 1
- else:
- # Epsilon and sample both use the same loss weights.
- mse_loss_weights = base_weight
-
- # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
- # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
- # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
- mse_loss_weights[snr == 0] = 1.0
-
- # We first calculate the original loss. Then we mean over the non-batch dimensions and
- # rebalance the sample-wise losses with their respective loss weights.
- # Finally, we take the mean of the rebalanced loss.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
index 15c17063bd68..f7100788cde2 100644
--- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
+++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
@@ -848,24 +848,13 @@ def collate_fn(examples):
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
- base_weight = (
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
- if noise_scheduler.config.prediction_type == "v_prediction":
- # velocity objective prediction requires SNR weights to be floored to a min value of 1.
- mse_loss_weights = base_weight + 1
- else:
- # Epsilon and sample prediction use the base weights.
- mse_loss_weights = base_weight
-
- # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
- # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
- # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
- mse_loss_weights[snr == 0] = 1.0
-
- # We first calculate the original loss. Then we mean over the non-batch dimensions and
- # rebalance the sample-wise losses with their respective loss weights.
- # Finally, we take the mean of the rebalanced loss.
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
diff --git a/examples/research_projects/rdm/pipeline_rdm.py b/examples/research_projects/rdm/pipeline_rdm.py
index 3e2653c5423d..28b4cacb8319 100644
--- a/examples/research_projects/rdm/pipeline_rdm.py
+++ b/examples/research_projects/rdm/pipeline_rdm.py
@@ -432,7 +432,8 @@ def __call__(
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
else:
diff --git a/examples/research_projects/sdxl_flax/README.md b/examples/research_projects/sdxl_flax/README.md
new file mode 100644
index 000000000000..fca21912982a
--- /dev/null
+++ b/examples/research_projects/sdxl_flax/README.md
@@ -0,0 +1,243 @@
+# Stable Diffusion XL for JAX + TPUv5e
+
+[TPU v5e](https://cloud.google.com/blog/products/compute/how-cloud-tpu-v5e-accelerates-large-scale-ai-inference) is a new generation of TPUs from Google Cloud. It is the most cost-effective, versatile, and scalable Cloud TPU to date. This makes them ideal for serving and scaling large diffusion models.
+
+[JAX](https://github.com/google/jax) is a high-performance numerical computation library that is well-suited to develop and deploy diffusion models:
+
+- **High performance**. All JAX operations are implemented in terms of operations in [XLA](https://www.tensorflow.org/xla/) - the Accelerated Linear Algebra compiler
+
+- **Compilation**. JAX uses just-in-time (jit) compilation of JAX Python functions so it can be executed efficiently in XLA. In order to get the best performance, we must use static shapes for jitted functions, this is because JAX transforms work by tracing a function and to determine its effect on inputs of a specific shape and type. When a new shape is introduced to an already compiled function, it retriggers compilation on the new shape, which can greatly reduce performance. **Note**: JIT compilation is particularly well-suited for text-to-image generation because all inputs and outputs (image input / output sizes) are static.
+
+- **Parallelization**. Workloads can be scaled across multiple devices using JAX's [pmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html), which expresses single-program multiple-data (SPMD) programs. Applying pmap to a function will compile a function with XLA, then execute in parallel on XLA devices. For text-to-image generation workloads this means that increasing the number of images rendered simultaneously is straightforward to implement and doesn't compromise performance.
+
+👉 Try it out for yourself:
+
+[](https://huggingface.co/spaces/google/sdxl)
+
+## Stable Diffusion XL pipeline in JAX
+
+Upon having access to a TPU VM (TPUs higher than version 3), you should first install
+a TPU-compatible version of JAX:
+```
+pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+```
+
+Next, we can install [flax](https://github.com/google/flax) and the diffusers library:
+
+```
+pip install flax diffusers transformers
+```
+
+In [sdxl_single.py](./sdxl_single.py) we give a simple example of how to write a text-to-image generation pipeline in JAX using [StabilityAI's Stable Diffusion XL](stabilityai/stable-diffusion-xl-base-1.0).
+
+Let's explain it step-by-step:
+
+**Imports and Setup**
+
+```python
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.jax_utils import replicate
+from diffusers import FlaxStableDiffusionXLPipeline
+
+from jax.experimental.compilation_cache import compilation_cache as cc
+cc.initialize_cache("/tmp/sdxl_cache")
+import time
+
+NUM_DEVICES = jax.device_count()
+```
+
+First, we import the necessary libraries:
+- `jax` is provides the primitives for TPU operations
+- `flax.jax_utils` contains some useful utility functions for `Flax`, a neural network library built on top of JAX
+- `diffusers` has all the code that is relevant for SDXL.
+- We also initialize a cache to speed up the JAX model compilation.
+- We automatically determine the number of available TPU devices.
+
+**1. Downloading Model and Loading Pipeline**
+
+```python
+pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
+)
+```
+Here, a pre-trained model `stable-diffusion-xl-base-1.0` from the namespace `stabilityai` is loaded. It returns a pipeline for inference and its parameters.
+
+**2. Casting Parameter Types**
+
+```python
+scheduler_state = params.pop("scheduler")
+params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
+params["scheduler"] = scheduler_state
+```
+This section adjusts the data types of the model parameters.
+We convert all parameters to `bfloat16` to speed-up the computation with model weights.
+**Note** that the scheduler parameters are **not** converted to `blfoat16` as the loss
+in precision is degrading the pipeline's performance too significantly.
+
+**3. Define Inputs to Pipeline**
+
+```python
+default_prompt = ...
+default_neg_prompt = ...
+default_seed = 33
+default_guidance_scale = 5.0
+default_num_steps = 25
+```
+Here, various default inputs for the pipeline are set, including the prompt, negative prompt, random seed, guidance scale, and the number of inference steps.
+
+**4. Tokenizing Inputs**
+
+```python
+def tokenize_prompt(prompt, neg_prompt):
+ prompt_ids = pipeline.prepare_inputs(prompt)
+ neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
+ return prompt_ids, neg_prompt_ids
+```
+This function tokenizes the given prompts. It's essential because the text encoders of SDXL don't understand raw text; they work with numbers. Tokenization converts text to numbers.
+
+**5. Parallelization and Replication**
+
+```python
+p_params = replicate(params)
+
+def replicate_all(prompt_ids, neg_prompt_ids, seed):
+ ...
+```
+To utilize JAX's parallel capabilities, the parameters and input tensors are duplicated across devices. The `replicate_all` function also ensures that every device produces a different image by creating a unique random seed for each device.
+
+**6. Putting Everything Together**
+
+```python
+def generate(...):
+ ...
+```
+This function integrates all the steps to produce the desired outputs from the model. It takes in prompts, tokenizes them, replicates them across devices, runs them through the pipeline, and converts the images to a format that's more interpretable (PIL format).
+
+**7. Compilation Step**
+
+```python
+start = time.time()
+print(f"Compiling ...")
+generate(default_prompt, default_neg_prompt)
+print(f"Compiled in {time.time() - start}")
+```
+The initial run of the `generate` function will be slow because JAX compiles the function during this call. By running it once here, subsequent calls will be much faster. This section measures and prints the compilation time.
+
+**8. Fast Inference**
+
+```python
+start = time.time()
+prompt = ...
+neg_prompt = ...
+images = generate(prompt, neg_prompt)
+print(f"Inference in {time.time() - start}")
+```
+Now that the function is compiled, this section shows how to use it for fast inference. It measures and prints the inference time.
+
+In summary, the code demonstrates how to load a pre-trained model using Flax and JAX, prepare it for inference, and run it efficiently using JAX's capabilities.
+
+## Ahead of Time (AOT) Compilation
+
+FlaxStableDiffusionXLPipeline takes care of parallelization across multiple devices using jit. Now let's build parallelization ourselves.
+
+For this we will be using a JAX feature called [Ahead of Time](https://jax.readthedocs.io/en/latest/aot.html) (AOT) lowering and compilation. AOT allows to fully compile prior to execution time and have control over different parts of the compilation process.
+
+In [sdxl_single_aot.py](./sdxl_single_aot.py) we give a simple example of how to write our own parallelization logic for text-to-image generation pipeline in JAX using [StabilityAI's Stable Diffusion XL](stabilityai/stable-diffusion-xl-base-1.0)
+
+We add a `aot_compile` function that compiles the `pipeline._generate` function
+telling JAX which input arguments are static, that is, arguments that
+are known at compile time and won't change. In our case, it is num_inference_steps,
+height, width and return_latents.
+
+Once the function is compiled, these parameters are ommited from future calls and
+cannot be changed without modifying the code and recompiling.
+
+```python
+def aot_compile(
+ prompt=default_prompt,
+ negative_prompt=default_neg_prompt,
+ seed=default_seed,
+ guidance_scale=default_guidance_scale,
+ num_inference_steps=default_num_steps
+):
+ prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
+ prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
+ g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
+ g = g[:, None]
+
+ return pmap(
+ pipeline._generate,static_broadcasted_argnums=[3, 4, 5, 9]
+ ).lower(
+ prompt_ids,
+ p_params,
+ rng,
+ num_inference_steps, # num_inference_steps
+ height, # height
+ width, # width
+ g,
+ None,
+ neg_prompt_ids,
+ False # return_latents
+ ).compile()
+````
+
+Next we can compile the generate function by executing `aot_compile`.
+
+```python
+start = time.time()
+print("Compiling ...")
+p_generate = aot_compile()
+print(f"Compiled in {time.time() - start}")
+```
+And again we put everything together in a `generate` function.
+
+```python
+def generate(
+ prompt,
+ negative_prompt,
+ seed=default_seed,
+ guidance_scale=default_guidance_scale
+):
+ prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
+ prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
+ g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
+ g = g[:, None]
+ images = p_generate(
+ prompt_ids,
+ p_params,
+ rng,
+ g,
+ None,
+ neg_prompt_ids)
+
+ # convert the images to PIL
+ images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
+ return pipeline.numpy_to_pil(np.array(images))
+```
+
+The first forward pass after AOT compilation still takes a while longer than
+subsequent passes, this is because on the first pass, JAX uses Python dispatch, which
+Fills the C++ dispatch cache.
+When using jit, this extra step is done automatically, but when using AOT compilation,
+it doesn't happen until the function call is made.
+
+```python
+start = time.time()
+prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
+neg_prompt = "cartoon, illustration, animation. face. male, female"
+images = generate(prompt, neg_prompt)
+print(f"First inference in {time.time() - start}")
+```
+
+From this point forward, any calls to generate should result in a faster inference
+time and it won't change.
+
+```python
+start = time.time()
+prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
+neg_prompt = "cartoon, illustration, animation. face. male, female"
+images = generate(prompt, neg_prompt)
+print(f"Inference in {time.time() - start}")
+```
diff --git a/examples/research_projects/sdxl_flax/sdxl_single.py b/examples/research_projects/sdxl_flax/sdxl_single.py
new file mode 100644
index 000000000000..5b9b862d99b5
--- /dev/null
+++ b/examples/research_projects/sdxl_flax/sdxl_single.py
@@ -0,0 +1,106 @@
+# Show best practices for SDXL JAX
+import time
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.jax_utils import replicate
+
+# Let's cache the model compilation, so that it doesn't take as long the next time around.
+from jax.experimental.compilation_cache import compilation_cache as cc
+
+from diffusers import FlaxStableDiffusionXLPipeline
+
+
+cc.initialize_cache("/tmp/sdxl_cache")
+
+
+NUM_DEVICES = jax.device_count()
+
+# 1. Let's start by downloading the model and loading it into our pipeline class
+# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and
+# will have to be passed to the pipeline during inference
+pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
+)
+
+# 2. We cast all parameters to bfloat16 EXCEPT the scheduler which we leave in
+# float32 to keep maximal precision
+scheduler_state = params.pop("scheduler")
+params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
+params["scheduler"] = scheduler_state
+
+# 3. Next, we define the different inputs to the pipeline
+default_prompt = "a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart"
+default_neg_prompt = "fog, grainy, purple"
+default_seed = 33
+default_guidance_scale = 5.0
+default_num_steps = 25
+
+
+# 4. In order to be able to compile the pipeline
+# all inputs have to be tensors or strings
+# Let's tokenize the prompt and negative prompt
+def tokenize_prompt(prompt, neg_prompt):
+ prompt_ids = pipeline.prepare_inputs(prompt)
+ neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
+ return prompt_ids, neg_prompt_ids
+
+
+# 5. To make full use of JAX's parallelization capabilities
+# the parameters and input tensors are duplicated across devices
+# To make sure every device generates a different image, we create
+# different seeds for each image. The model parameters won't change
+# during inference so we do not wrap them into a function
+p_params = replicate(params)
+
+
+def replicate_all(prompt_ids, neg_prompt_ids, seed):
+ p_prompt_ids = replicate(prompt_ids)
+ p_neg_prompt_ids = replicate(neg_prompt_ids)
+ rng = jax.random.PRNGKey(seed)
+ rng = jax.random.split(rng, NUM_DEVICES)
+ return p_prompt_ids, p_neg_prompt_ids, rng
+
+
+# 6. Let's now put it all together in a generate function
+def generate(
+ prompt,
+ negative_prompt,
+ seed=default_seed,
+ guidance_scale=default_guidance_scale,
+ num_inference_steps=default_num_steps,
+):
+ prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
+ prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
+ images = pipeline(
+ prompt_ids,
+ p_params,
+ rng,
+ num_inference_steps=num_inference_steps,
+ neg_prompt_ids=neg_prompt_ids,
+ guidance_scale=guidance_scale,
+ jit=True,
+ ).images
+
+ # convert the images to PIL
+ images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+ return pipeline.numpy_to_pil(np.array(images))
+
+
+# 7. Remember that the first call will compile the function and hence be very slow. Let's run generate once
+# so that the pipeline call is compiled
+start = time.time()
+print("Compiling ...")
+generate(default_prompt, default_neg_prompt)
+print(f"Compiled in {time.time() - start}")
+
+# 8. Now the model forward pass will run very quickly, let's try it again
+start = time.time()
+prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
+neg_prompt = "cartoon, illustration, animation. face. male, female"
+images = generate(prompt, neg_prompt)
+print(f"Inference in {time.time() - start}")
+
+for i, image in enumerate(images):
+ image.save(f"castle_{i}.png")
diff --git a/examples/research_projects/sdxl_flax/sdxl_single_aot.py b/examples/research_projects/sdxl_flax/sdxl_single_aot.py
new file mode 100644
index 000000000000..58447fd86daf
--- /dev/null
+++ b/examples/research_projects/sdxl_flax/sdxl_single_aot.py
@@ -0,0 +1,143 @@
+import time
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.jax_utils import replicate
+from jax import pmap
+
+# Let's cache the model compilation, so that it doesn't take as long the next time around.
+from jax.experimental.compilation_cache import compilation_cache as cc
+
+from diffusers import FlaxStableDiffusionXLPipeline
+
+
+cc.initialize_cache("/tmp/sdxl_cache")
+
+
+NUM_DEVICES = jax.device_count()
+
+# 1. Let's start by downloading the model and loading it into our pipeline class
+# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and
+# will have to be passed to the pipeline during inference
+pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
+)
+
+# 2. We cast all parameters to bfloat16 EXCEPT the scheduler which we leave in
+# float32 to keep maximal precision
+scheduler_state = params.pop("scheduler")
+params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
+params["scheduler"] = scheduler_state
+
+# 3. Next, we define the different inputs to the pipeline
+default_prompt = "a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart"
+default_neg_prompt = "fog, grainy, purple"
+default_seed = 33
+default_guidance_scale = 5.0
+default_num_steps = 25
+width = 1024
+height = 1024
+
+
+# 4. In order to be able to compile the pipeline
+# all inputs have to be tensors or strings
+# Let's tokenize the prompt and negative prompt
+def tokenize_prompt(prompt, neg_prompt):
+ prompt_ids = pipeline.prepare_inputs(prompt)
+ neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
+ return prompt_ids, neg_prompt_ids
+
+
+# 5. To make full use of JAX's parallelization capabilities
+# the parameters and input tensors are duplicated across devices
+# To make sure every device generates a different image, we create
+# different seeds for each image. The model parameters won't change
+# during inference so we do not wrap them into a function
+p_params = replicate(params)
+
+
+def replicate_all(prompt_ids, neg_prompt_ids, seed):
+ p_prompt_ids = replicate(prompt_ids)
+ p_neg_prompt_ids = replicate(neg_prompt_ids)
+ rng = jax.random.PRNGKey(seed)
+ rng = jax.random.split(rng, NUM_DEVICES)
+ return p_prompt_ids, p_neg_prompt_ids, rng
+
+
+# 6. To compile the pipeline._generate function, we must pass all parameters
+# to the function and tell JAX which are static arguments, that is, arguments that
+# are known at compile time and won't change. In our case, it is num_inference_steps,
+# height, width and return_latents.
+# Once the function is compiled, these parameters are ommited from future calls and
+# cannot be changed without modifying the code and recompiling.
+def aot_compile(
+ prompt=default_prompt,
+ negative_prompt=default_neg_prompt,
+ seed=default_seed,
+ guidance_scale=default_guidance_scale,
+ num_inference_steps=default_num_steps,
+):
+ prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
+ prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
+ g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
+ g = g[:, None]
+
+ return (
+ pmap(pipeline._generate, static_broadcasted_argnums=[3, 4, 5, 9])
+ .lower(
+ prompt_ids,
+ p_params,
+ rng,
+ num_inference_steps, # num_inference_steps
+ height, # height
+ width, # width
+ g,
+ None,
+ neg_prompt_ids,
+ False, # return_latents
+ )
+ .compile()
+ )
+
+
+start = time.time()
+print("Compiling ...")
+p_generate = aot_compile()
+print(f"Compiled in {time.time() - start}")
+
+
+# 7. Let's now put it all together in a generate function.
+def generate(prompt, negative_prompt, seed=default_seed, guidance_scale=default_guidance_scale):
+ prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
+ prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
+ g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
+ g = g[:, None]
+ images = p_generate(prompt_ids, p_params, rng, g, None, neg_prompt_ids)
+
+ # convert the images to PIL
+ images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+ return pipeline.numpy_to_pil(np.array(images))
+
+
+# 8. The first forward pass after AOT compilation still takes a while longer than
+# subsequent passes, this is because on the first pass, JAX uses Python dispatch, which
+# Fills the C++ dispatch cache.
+# When using jit, this extra step is done automatically, but when using AOT compilation,
+# it doesn't happen until the function call is made.
+start = time.time()
+prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
+neg_prompt = "cartoon, illustration, animation. face. male, female"
+images = generate(prompt, neg_prompt)
+print(f"First inference in {time.time() - start}")
+
+# 9. From this point forward, any calls to generate should result in a faster inference
+# time and it won't change.
+start = time.time()
+prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
+neg_prompt = "cartoon, illustration, animation. face. male, female"
+images = generate(prompt, neg_prompt)
+print(f"Inference in {time.time() - start}")
+
+for i, image in enumerate(images):
+ image.save(f"castle_{i}.png")
diff --git a/examples/text_to_image/requirements_sdxl.txt b/examples/text_to_image/requirements_sdxl.txt
index 5d67662fadbe..cdd3336e3617 100644
--- a/examples/text_to_image/requirements_sdxl.txt
+++ b/examples/text_to_image/requirements_sdxl.txt
@@ -4,3 +4,4 @@ transformers>=4.25.1
ftfy
tensorboard
Jinja2
+datasets
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 535942629314..e216529b2f54 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -577,9 +577,10 @@ def deepspeed_zero_init_disabled_context_manager():
args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
)
- # Freeze vae and text_encoder
+ # Freeze vae and text_encoder and set unet to trainable
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
+ unet.train()
# Create EMA for the unet.
if args.use_ema:
@@ -854,29 +855,29 @@ def collate_fn(examples):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
+ initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
- resume_global_step = global_step * args.gradient_accumulation_steps
+ initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
- # Only show the progress bar once on each machine.
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
- progress_bar.set_description("Steps")
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
for epoch in range(first_epoch, args.num_train_epochs):
- unet.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
- # Skip steps until we reach the resumed step
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
- if step % args.gradient_accumulation_steps == 0:
- progress_bar.update(1)
- continue
-
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
@@ -928,25 +929,13 @@ def collate_fn(examples):
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
- base_weight = (
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
- if noise_scheduler.config.prediction_type == "v_prediction":
- # Velocity objective needs to be floored to an SNR weight of one.
- mse_loss_weights = base_weight + 1
- else:
- # Epsilon and sample both use the same loss weights.
- mse_loss_weights = base_weight
-
- # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
- # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
- # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
- mse_loss_weights[snr == 0] = 1.0
-
- # We first calculate the original loss. Then we mean over the non-batch dimensions and
- # rebalance the sample-wise losses with their respective loss weights.
- # Finally, we take the mean of the rebalanced loss.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index d4d13a144f38..eac0f18f49f4 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -429,7 +429,6 @@ def main():
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
-
text_encoder.requires_grad_(False)
# For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
@@ -690,29 +689,29 @@ def collate_fn(examples):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
+ initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
- resume_global_step = global_step * args.gradient_accumulation_steps
+ initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+ else:
+ initial_global_step = 0
- # Only show the progress bar once on each machine.
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
- progress_bar.set_description("Steps")
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
- # Skip steps until we reach the resumed step
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
- if step % args.gradient_accumulation_steps == 0:
- progress_bar.update(1)
- continue
-
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -760,25 +759,13 @@ def collate_fn(examples):
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
- base_weight = (
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
- if noise_scheduler.config.prediction_type == "v_prediction":
- # Velocity objective needs to be floored to an SNR weight of one.
- mse_loss_weights = base_weight + 1
- else:
- # Epsilon and sample both use the same loss weights.
- mse_loss_weights = base_weight
-
- # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
- # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
- # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
- mse_loss_weights[snr == 0] = 1.0
-
- # We first calculate the original loss. Then we mean over the non-batch dimensions and
- # rebalance the sample-wise losses with their respective loss weights.
- # Finally, we take the mean of the rebalanced loss.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index 991f8a84a243..ed7a15cd95fe 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -947,18 +947,25 @@ def collate_fn(examples):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
+ initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
- resume_global_step = global_step * args.gradient_accumulation_steps
+ initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
- # Only show the progress bar once on each machine.
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
- progress_bar.set_description("Steps")
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
@@ -967,12 +974,6 @@ def collate_fn(examples):
text_encoder_two.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
- # Skip steps until we reach the resumed step
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
- if step % args.gradient_accumulation_steps == 0:
- progress_bar.update(1)
- continue
-
with accelerator.accumulate(unet):
# Convert images to latent space
if args.pretrained_vae_model_name_or_path is not None:
@@ -1049,25 +1050,13 @@ def compute_time_ids(original_size, crops_coords_top_left):
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
- base_weight = (
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
- if noise_scheduler.config.prediction_type == "v_prediction":
- # Velocity objective needs to be floored to an SNR weight of one.
- mse_loss_weights = base_weight + 1
- else:
- # Epsilon and sample both use the same loss weights.
- mse_loss_weights = base_weight
-
- # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
- # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
- # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
- mse_loss_weights[snr == 0] = 1.0
-
- # We first calculate the original loss. Then we mean over the non-batch dimensions and
- # rebalance the sample-wise losses with their respective loss weights.
- # Finally, we take the mean of the rebalanced loss.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index 649b82ed3baa..c681943f2e94 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -657,6 +657,8 @@ def main(args):
vae.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
+ # Set unet as trainable.
+ unet.train()
# For mixed precision training we cast all non-trainable weigths to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
@@ -967,29 +969,29 @@ def collate_fn(examples):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
+ initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
- resume_global_step = global_step * args.gradient_accumulation_steps
+ initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
- # Only show the progress bar once on each machine.
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
- progress_bar.set_description("Steps")
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
for epoch in range(first_epoch, args.num_train_epochs):
- unet.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
- # Skip steps until we reach the resumed step
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
- if step % args.gradient_accumulation_steps == 0:
- progress_bar.update(1)
- continue
-
with accelerator.accumulate(unet):
# Sample noise that we'll add to the latents
model_input = batch["model_input"].to(accelerator.device)
@@ -1065,25 +1067,13 @@ def compute_time_ids(original_size, crops_coords_top_left):
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
- base_weight = (
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
- if noise_scheduler.config.prediction_type == "v_prediction":
- # Velocity objective needs to be floored to an SNR weight of one.
- mse_loss_weights = base_weight + 1
- else:
- # Epsilon and sample both use the same loss weights.
- mse_loss_weights = base_weight
-
- # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
- # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
- # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
- mse_loss_weights[snr == 0] = 1.0
-
- # We first calculate the original loss. Then we mean over the non-batch dimensions and
- # rebalance the sample-wise losses with their respective loss weights.
- # Finally, we take the mean of the rebalanced loss.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 2e6f9a7d9522..01830751ffe2 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -809,18 +809,25 @@ def main():
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
+ initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
- resume_global_step = global_step * args.gradient_accumulation_steps
+ initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
- resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
- # Only show the progress bar once on each machine.
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
- progress_bar.set_description("Steps")
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
# keep original embeddings as reference
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
@@ -828,12 +835,6 @@ def main():
for epoch in range(first_epoch, args.num_train_epochs):
text_encoder.train()
for step, batch in enumerate(train_dataloader):
- # Skip steps until we reach the resumed step
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
- if step % args.gradient_accumulation_steps == 0:
- progress_bar.update(1)
- continue
-
with accelerator.accumulate(text_encoder):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 4925c74c8ccf..a3baa3b85b36 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -607,28 +607,28 @@ def transform_images(examples):
progress_bar.update(1)
global_step += 1
- if global_step % args.checkpointing_steps == 0:
- # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
- if args.checkpoints_total_limit is not None:
- checkpoints = os.listdir(args.output_dir)
- checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
- checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
-
- # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
- if len(checkpoints) >= args.checkpoints_total_limit:
- num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
- removing_checkpoints = checkpoints[0:num_to_remove]
-
- logger.info(
- f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
- )
- logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
-
- for removing_checkpoint in removing_checkpoints:
- removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
- shutil.rmtree(removing_checkpoint)
-
- if accelerator.is_main_process:
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
diff --git a/scripts/convert_unidiffuser_to_diffusers.py b/scripts/convert_unidiffuser_to_diffusers.py
index 891d289d8c76..4c38172754f6 100644
--- a/scripts/convert_unidiffuser_to_diffusers.py
+++ b/scripts/convert_unidiffuser_to_diffusers.py
@@ -73,17 +73,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
- new_item = new_item.replace("q.weight", "query.weight")
- new_item = new_item.replace("q.bias", "query.bias")
+ new_item = new_item.replace("q.weight", "to_q.weight")
+ new_item = new_item.replace("q.bias", "to_q.bias")
- new_item = new_item.replace("k.weight", "key.weight")
- new_item = new_item.replace("k.bias", "key.bias")
+ new_item = new_item.replace("k.weight", "to_k.weight")
+ new_item = new_item.replace("k.bias", "to_k.bias")
- new_item = new_item.replace("v.weight", "value.weight")
- new_item = new_item.replace("v.bias", "value.bias")
+ new_item = new_item.replace("v.weight", "to_v.weight")
+ new_item = new_item.replace("v.bias", "to_v.bias")
- new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
- new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
@@ -92,6 +92,19 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
return mapping
+# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
+def conv_attn_to_linear(checkpoint):
+ keys = list(checkpoint.keys())
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
+ for key in keys:
+ if ".".join(key.split(".")[-2:]) in attn_keys:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
+ elif "proj_attn.weight" in key:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0]
+
+
# Modified from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint
# config.num_head_channels => num_head_channels
def assign_to_checkpoint(
@@ -104,8 +117,9 @@ def assign_to_checkpoint(
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
- attention layers, and takes into account additional replacements that may arise. Assigns the weights to the new
- checkpoint.
+ attention layers, and takes into account additional replacements that may arise.
+
+ Assigns the weights to the new checkpoint.
"""
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
@@ -143,25 +157,16 @@ def assign_to_checkpoint(
new_path = new_path.replace(replacement["old"], replacement["new"])
# proj_attn.weight has to be converted from conv 1D to linear
- if "proj_attn.weight" in new_path:
+ is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
+ shape = old_checkpoint[path["old"]].shape
+ if is_attn_weight and len(shape) == 3:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+ elif is_attn_weight and len(shape) == 4:
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
else:
checkpoint[new_path] = old_checkpoint[path["old"]]
-# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
-def conv_attn_to_linear(checkpoint):
- keys = list(checkpoint.keys())
- attn_keys = ["query.weight", "key.weight", "value.weight"]
- for key in keys:
- if ".".join(key.split(".")[-2:]) in attn_keys:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
- elif "proj_attn.weight" in key:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0]
-
-
def create_vae_diffusers_config(config_type):
# Hardcoded for now
if args.config_type == "test":
@@ -339,7 +344,7 @@ def create_text_decoder_config_big():
return text_decoder_config
-# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments.convert_ldm_vae_checkpoint
+# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint
def convert_vae_to_diffusers(ckpt, diffusers_model, num_head_channels=1):
"""
Converts a UniDiffuser autoencoder_kl.pth checkpoint to a diffusers AutoencoderKL.
@@ -674,6 +679,11 @@ def convert_caption_decoder_to_diffusers(ckpt, diffusers_model):
type=int,
help="The UniDiffuser model type to convert to. Should be 0 for UniDiffuser-v0 and 1 for UniDiffuser-v1.",
)
+ parser.add_argument(
+ "--safe_serialization",
+ action="store_true",
+ help="Whether to use safetensors/safe seialization when saving the pipeline.",
+ )
args = parser.parse_args()
@@ -766,11 +776,11 @@ def convert_caption_decoder_to_diffusers(ckpt, diffusers_model):
vae=vae,
text_encoder=text_encoder,
image_encoder=image_encoder,
- image_processor=image_processor,
+ clip_image_processor=image_processor,
clip_tokenizer=clip_tokenizer,
text_decoder=text_decoder,
text_tokenizer=text_tokenizer,
unet=unet,
scheduler=scheduler,
)
- pipeline.save_pretrained(args.pipeline_output_path)
+ pipeline.save_pretrained(args.pipeline_output_path, safe_serialization=args.safe_serialization)
diff --git a/setup.py b/setup.py
index a2201ac5b3b1..7ad5646d4fca 100644
--- a/setup.py
+++ b/setup.py
@@ -102,8 +102,8 @@
"importlib_metadata",
"invisible-watermark>=0.2.0",
"isort>=5.5.4",
- "jax>=0.2.8,!=0.3.2",
- "jaxlib>=0.1.65",
+ "jax>=0.4.1",
+ "jaxlib>=0.4.1",
"Jinja2",
"k-diffusion>=0.0.12",
"torchsde",
@@ -255,6 +255,7 @@ def run(self):
url="https://github.com/huggingface/diffusers",
package_dir={"": "src"},
packages=find_packages("src"),
+ package_data={"diffusers": ["py.typed"]},
include_package_data=True,
python_requires=">=3.8.0",
install_requires=list(install_requires),
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index d4b94ba6d4ed..970013c31a20 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -15,8 +15,8 @@
"importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0",
"isort": "isort>=5.5.4",
- "jax": "jax>=0.2.8,!=0.3.2",
- "jaxlib": "jaxlib>=0.1.65",
+ "jax": "jax>=0.4.1",
+ "jaxlib": "jaxlib>=0.4.1",
"Jinja2": "Jinja2",
"k-diffusion": "k-diffusion>=0.0.12",
"torchsde": "torchsde",
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 11e600e59884..2cc547be0178 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -27,6 +27,7 @@
from packaging import version
from torch import nn
+from . import __version__
from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from .utils import (
DIFFUSERS_CACHE,
@@ -120,7 +121,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
- def _fuse_lora(self, lora_scale=1.0):
+ def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
if self.lora_linear_layer is None:
return
@@ -134,6 +135,14 @@ def _fuse_lora(self, lora_scale=1.0):
w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+
+ if safe_fusing and torch.isnan(fused_weight).any().item():
+ raise ValueError(
+ "This LoRA weight seems to be broken. "
+ f"Encountered NaN values when trying to fuse LoRA weights for {self}."
+ "LoRA weights will not be fused."
+ )
+
self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
# we can drop the lora layer now
@@ -671,13 +680,14 @@ def save_function(weights, filename):
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
- def fuse_lora(self, lora_scale=1.0):
+ def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
self.lora_scale = lora_scale
+ self._safe_fusing = safe_fusing
self.apply(self._fuse_lora_apply)
def _fuse_lora_apply(self, module):
if hasattr(module, "_fuse_lora"):
- module._fuse_lora(self.lora_scale)
+ module._fuse_lora(self.lora_scale, self._safe_fusing)
def unfuse_lora(self):
self.apply(self._unfuse_lora_apply)
@@ -1708,7 +1718,8 @@ def _remove_text_encoder_monkey_patch(self):
@classmethod
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
- deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.23", LORA_DEPRECATION_MESSAGE)
+ if version.parse(__version__) > version.parse("0.23"):
+ deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.25", LORA_DEPRECATION_MESSAGE)
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
@@ -1736,7 +1747,8 @@ def _modify_text_encoder(
r"""
Monkey-patches the forward passes of attention modules of the text encoder.
"""
- deprecate("_modify_text_encoder", "0.23", LORA_DEPRECATION_MESSAGE)
+ if version.parse(__version__) > version.parse("0.23"):
+ deprecate("_modify_text_encoder", "0.25", LORA_DEPRECATION_MESSAGE)
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
@@ -2083,7 +2095,13 @@ def unload_lora_weights(self):
# Safe to call the following regardless of LoRA.
self._remove_text_encoder_monkey_patch()
- def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0):
+ def fuse_lora(
+ self,
+ fuse_unet: bool = True,
+ fuse_text_encoder: bool = True,
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ ):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
@@ -2100,6 +2118,8 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora
LoRA parameters then it won't have any effect.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
"""
if fuse_unet or fuse_text_encoder:
self.num_fused_loras += 1
@@ -2109,12 +2129,13 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora
)
if fuse_unet:
- self.unet.fuse_lora(lora_scale)
+ self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
if self.use_peft_backend:
from peft.tuners.tuners_utils import BaseTunerLayer
- def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
+ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
+ # TODO(Patrick, Younes): enable "safe" fusing
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
if lora_scale != 1.0:
@@ -2123,26 +2144,27 @@ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
module.merge()
else:
- deprecate("fuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
+ if version.parse(__version__) > version.parse("0.23"):
+ deprecate("fuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
- def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
+ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
- attn_module.q_proj._fuse_lora(lora_scale)
- attn_module.k_proj._fuse_lora(lora_scale)
- attn_module.v_proj._fuse_lora(lora_scale)
- attn_module.out_proj._fuse_lora(lora_scale)
+ attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
+ attn_module.k_proj._fuse_lora(lora_scale, safe_fusing)
+ attn_module.v_proj._fuse_lora(lora_scale, safe_fusing)
+ attn_module.out_proj._fuse_lora(lora_scale, safe_fusing)
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
- mlp_module.fc1._fuse_lora(lora_scale)
- mlp_module.fc2._fuse_lora(lora_scale)
+ mlp_module.fc1._fuse_lora(lora_scale, safe_fusing)
+ mlp_module.fc2._fuse_lora(lora_scale, safe_fusing)
if fuse_text_encoder:
if hasattr(self, "text_encoder"):
- fuse_text_encoder_lora(self.text_encoder, lora_scale)
+ fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing)
if hasattr(self, "text_encoder_2"):
- fuse_text_encoder_lora(self.text_encoder_2, lora_scale)
+ fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing)
def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
r"""
@@ -2173,7 +2195,8 @@ def unfuse_text_encoder_lora(text_encoder):
module.unmerge()
else:
- deprecate("unfuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
+ if version.parse(__version__) > version.parse("0.23"):
+ deprecate("unfuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
def unfuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
@@ -2428,8 +2451,12 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
from .models.controlnet import ControlNetModel
from .pipelines.controlnet.multicontrolnet import MultiControlNetModel
- # Model type will be inferred from the checkpoint.
- if not isinstance(controlnet, (ControlNetModel, MultiControlNetModel)):
+ # list/tuple or a single instance of ControlNetModel or MultiControlNetModel
+ if not (
+ isinstance(controlnet, (ControlNetModel, MultiControlNetModel))
+ or isinstance(controlnet, (list, tuple))
+ and isinstance(controlnet[0], ControlNetModel)
+ ):
raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.")
elif "StableDiffusion" in pipeline_name:
# Model type will be inferred from the checkpoint.
diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py
index 04c978403f41..46da899096c2 100644
--- a/src/diffusers/models/activations.py
+++ b/src/diffusers/models/activations.py
@@ -1,7 +1,15 @@
from torch import nn
-def get_activation(act_fn):
+def get_activation(act_fn: str) -> nn.Module:
+ """Helper function to get activation function from string.
+
+ Args:
+ act_fn (str): Name of activation function.
+
+ Returns:
+ nn.Module: Activation function.
+ """
if act_fn in ["swish", "silu"]:
return nn.SiLU()
elif act_fn == "mish":
diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py
index a5e7d0af92d1..bf6803c565fe 100644
--- a/src/diffusers/models/adapter.py
+++ b/src/diffusers/models/adapter.py
@@ -258,6 +258,12 @@ def __init__(
)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ r"""
+ This function processes the input tensor `x` through the adapter model and returns a list of feature tensors,
+ each representing information extracted at a different scale from the input. The length of the list is
+ determined by the number of downsample blocks in the Adapter, as specified by the `channels` and
+ `num_res_blocks` parameters during initialization.
+ """
return self.adapter(x)
@property
@@ -296,6 +302,12 @@ def __init__(
self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ r"""
+ This method processes the input tensor `x` through the FullAdapter model and performs operations including
+ pixel unshuffling, convolution, and a stack of AdapterBlocks. It returns a list of feature tensors, each
+ capturing information at a different stage of processing within the FullAdapter model. The number of feature
+ tensors in the list is determined by the number of downsample blocks specified during initialization.
+ """
x = self.unshuffle(x)
x = self.conv_in(x)
@@ -338,6 +350,10 @@ def __init__(
self.total_downscale_factor = downscale_factor * 2
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ r"""
+ This method takes the tensor x as input and processes it through FullAdapterXL model. It consists of operations
+ including unshuffling pixels, applying convolution layer and appending each block into list of feature tensors.
+ """
x = self.unshuffle(x)
x = self.conv_in(x)
@@ -367,6 +383,11 @@ def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
)
def forward(self, x):
+ r"""
+ This method takes tensor x as input and performs operations downsampling and convolutional layers if the
+ self.downsample and self.in_conv properties of AdapterBlock model are specified. Then it applies a series of
+ residual blocks to the input tensor.
+ """
if self.downsample is not None:
x = self.downsample(x)
@@ -386,6 +407,10 @@ def __init__(self, channels):
self.block2 = nn.Conv2d(channels, channels, kernel_size=1)
def forward(self, x):
+ r"""
+ This method takes input tensor x and applies a convolutional layer, ReLU activation, and another convolutional
+ layer on the input tensor. It returns addition with the input tensor.
+ """
h = x
h = self.block1(h)
h = self.act(h)
@@ -425,6 +450,10 @@ def __init__(
self.total_downscale_factor = downscale_factor * (2 ** len(channels))
def forward(self, x):
+ r"""
+ This method takes the input tensor x and performs downscaling and appends it in list of feature tensors. Each
+ feature tensor corresponds to a different level of processing within the LightAdapter.
+ """
x = self.unshuffle(x)
features = []
@@ -450,6 +479,10 @@ def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
def forward(self, x):
+ r"""
+ This method takes tensor x as input and performs downsampling if required. Then it applies in convolution
+ layer, a sequence of residual blocks, and out convolutional layer.
+ """
if self.downsample is not None:
x = self.downsample(x)
@@ -468,6 +501,10 @@ def __init__(self, channels):
self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
def forward(self, x):
+ r"""
+ This function takes input tensor x and processes it through one convolutional layer, ReLU activation, and
+ another convolutional layer and adds it to input tensor.
+ """
h = x
h = self.block1(h)
h = self.act(h)
diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py
index 21c8f64fd916..80d2cccd536d 100644
--- a/src/diffusers/models/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoder_kl.py
@@ -249,7 +249,21 @@ def set_default_attn_processor(self):
self.set_attn_processor(processor, _remove_lora=True)
@apply_forward_hook
- def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+ def encode(
+ self, x: torch.FloatTensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.FloatTensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded images. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
return self.tiled_encode(x, return_dict=return_dict)
@@ -281,6 +295,20 @@ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decod
@apply_forward_hook
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ """
+ Decode a batch of images.
+
+ Args:
+ z (`torch.FloatTensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+
+ """
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py
index 1a82b0421f88..c0d2da9b8c5f 100644
--- a/src/diffusers/models/controlnet.py
+++ b/src/diffusers/models/controlnet.py
@@ -671,7 +671,13 @@ def forward(
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
+ embeddings.
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnet_flax.py
index a826df48e41a..076e6183211b 100644
--- a/src/diffusers/models/controlnet_flax.py
+++ b/src/diffusers/models/controlnet_flax.py
@@ -168,7 +168,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
controlnet_conditioning_channel_order: str = "rgb"
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
- def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
+ def init_weights(self, rng: jax.Array) -> FrozenDict:
# init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
diff --git a/src/diffusers/models/dual_transformer_2d.py b/src/diffusers/models/dual_transformer_2d.py
index 3db7e73ca6af..02568298409c 100644
--- a/src/diffusers/models/dual_transformer_2d.py
+++ b/src/diffusers/models/dual_transformer_2d.py
@@ -107,14 +107,18 @@ def forward(
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
- hidden_states
+ hidden_states.
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
attention_mask (`torch.FloatTensor`, *optional*):
- Optional attention mask to be applied in Attention
+ Optional attention mask to be applied in Attention.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py
index cc8e3e231e2b..aec7200afdfe 100644
--- a/src/diffusers/models/lora.py
+++ b/src/diffusers/models/lora.py
@@ -112,7 +112,7 @@ def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
self.lora_layer = lora_layer
- def _fuse_lora(self, lora_scale=1.0):
+ def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
if self.lora_layer is None:
return
@@ -128,6 +128,14 @@ def _fuse_lora(self, lora_scale=1.0):
fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
fusion = fusion.reshape((w_orig.shape))
fused_weight = w_orig + (lora_scale * fusion)
+
+ if safe_fusing and torch.isnan(fused_weight).any().item():
+ raise ValueError(
+ "This LoRA weight seems to be broken. "
+ f"Encountered NaN values when trying to fuse LoRA weights for {self}."
+ "LoRA weights will not be fused."
+ )
+
self.weight.data = fused_weight.to(device=device, dtype=dtype)
# we can drop the lora layer now
@@ -164,7 +172,10 @@ def forward(self, hidden_states, scale: float = 1.0):
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
)
else:
- return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
+ original_outputs = F.conv2d(
+ hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+ )
+ return original_outputs + (scale * self.lora_layer(hidden_states))
class LoRACompatibleLinear(nn.Linear):
@@ -179,7 +190,7 @@ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
self.lora_layer = lora_layer
- def _fuse_lora(self, lora_scale=1.0):
+ def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
if self.lora_layer is None:
return
@@ -193,6 +204,14 @@ def _fuse_lora(self, lora_scale=1.0):
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+
+ if safe_fusing and torch.isnan(fused_weight).any().item():
+ raise ValueError(
+ "This LoRA weight seems to be broken. "
+ f"Encountered NaN values when trying to fuse LoRA weights for {self}."
+ "LoRA weights will not be fused."
+ )
+
self.weight.data = fused_weight.to(device=device, dtype=dtype)
# we can drop the lora layer now
diff --git a/src/diffusers/models/modeling_flax_utils.py b/src/diffusers/models/modeling_flax_utils.py
index 97f7b43bc64e..ea4d1bfea548 100644
--- a/src/diffusers/models/modeling_flax_utils.py
+++ b/src/diffusers/models/modeling_flax_utils.py
@@ -192,7 +192,7 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
```"""
return self._cast_floating_to(params, jnp.float16, mask)
- def init_weights(self, rng: jax.random.KeyArray) -> Dict:
+ def init_weights(self, rng: jax.Array) -> Dict:
raise NotImplementedError(f"init_weights method has to be implemented for {self}")
@classmethod
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index ac66e2271c61..3972b438b076 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -14,7 +14,7 @@
# limitations under the License.
from functools import partial
-from typing import Optional
+from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -38,9 +38,18 @@ class Upsample1D(nn.Module):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
+ name (`str`, default `conv`):
+ name of the upsampling 1D layer.
"""
- def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ def __init__(
+ self,
+ channels: int,
+ use_conv: bool = False,
+ use_conv_transpose: bool = False,
+ out_channels: Optional[int] = None,
+ name: str = "conv",
+ ):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -54,7 +63,7 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann
elif use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
- def forward(self, inputs):
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
assert inputs.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(inputs)
@@ -79,9 +88,18 @@ class Downsample1D(nn.Module):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
+ name (`str`, default `conv`):
+ name of the downsampling 1D layer.
"""
- def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ def __init__(
+ self,
+ channels: int,
+ use_conv: bool = False,
+ out_channels: Optional[int] = None,
+ padding: int = 1,
+ name: str = "conv",
+ ):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -96,7 +114,7 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name=
assert self.channels == self.out_channels
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
- def forward(self, inputs):
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
assert inputs.shape[1] == self.channels
return self.conv(inputs)
@@ -113,9 +131,18 @@ class Upsample2D(nn.Module):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
+ name (`str`, default `conv`):
+ name of the upsampling 2D layer.
"""
- def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ def __init__(
+ self,
+ channels: int,
+ use_conv: bool = False,
+ use_conv_transpose: bool = False,
+ out_channels: Optional[int] = None,
+ name: str = "conv",
+ ):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -135,7 +162,7 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann
else:
self.Conv2d_0 = conv
- def forward(self, hidden_states, output_size=None, scale: float = 1.0):
+ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, scale: float = 1.0):
assert hidden_states.shape[1] == self.channels
if self.use_conv_transpose:
@@ -191,9 +218,18 @@ class Downsample2D(nn.Module):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
+ name (`str`, default `conv`):
+ name of the downsampling 2D layer.
"""
- def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ def __init__(
+ self,
+ channels: int,
+ use_conv: bool = False,
+ out_channels: Optional[int] = None,
+ padding: int = 1,
+ name: str = "conv",
+ ):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -246,7 +282,13 @@ class FirUpsample2D(nn.Module):
kernel for the FIR filter.
"""
- def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+ def __init__(
+ self,
+ channels: int = None,
+ out_channels: Optional[int] = None,
+ use_conv: bool = False,
+ fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
+ ):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
@@ -255,7 +297,14 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
self.fir_kernel = fir_kernel
self.out_channels = out_channels
- def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
+ def _upsample_2d(
+ self,
+ hidden_states: torch.Tensor,
+ weight: Optional[torch.Tensor] = None,
+ kernel: Optional[torch.FloatTensor] = None,
+ factor: int = 2,
+ gain: float = 1,
+ ) -> torch.Tensor:
"""Fused `upsample_2d()` followed by `Conv2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
@@ -335,7 +384,7 @@ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1
return output
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.use_conv:
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -359,7 +408,13 @@ class FirDownsample2D(nn.Module):
kernel for the FIR filter.
"""
- def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+ def __init__(
+ self,
+ channels: int = None,
+ out_channels: Optional[int] = None,
+ use_conv: bool = False,
+ fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
+ ):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
@@ -368,7 +423,14 @@ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=
self.use_conv = use_conv
self.out_channels = out_channels
- def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
+ def _downsample_2d(
+ self,
+ hidden_states: torch.Tensor,
+ weight: Optional[torch.Tensor] = None,
+ kernel: Optional[torch.FloatTensor] = None,
+ factor: int = 2,
+ gain: float = 1,
+ ) -> torch.Tensor:
"""Fused `Conv2d()` followed by `downsample_2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
@@ -422,7 +484,7 @@ def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain
return output
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.use_conv:
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -434,14 +496,20 @@ def forward(self, hidden_states):
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
class KDownsample2D(nn.Module):
- def __init__(self, pad_mode="reflect"):
+ r"""A 2D K-downsampling layer.
+
+ Parameters:
+ pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
+ """
+
+ def __init__(self, pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
- def forward(self, inputs):
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(inputs.shape[1], device=inputs.device)
@@ -451,14 +519,20 @@ def forward(self, inputs):
class KUpsample2D(nn.Module):
- def __init__(self, pad_mode="reflect"):
+ r"""A 2D K-upsampling layer.
+
+ Parameters:
+ pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
+ """
+
+ def __init__(self, pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
- def forward(self, inputs):
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(inputs.shape[1], device=inputs.device)
@@ -501,23 +575,23 @@ class ResnetBlock2D(nn.Module):
def __init__(
self,
*,
- in_channels,
- out_channels=None,
- conv_shortcut=False,
- dropout=0.0,
- temb_channels=512,
- groups=32,
- groups_out=None,
- pre_norm=True,
- eps=1e-6,
- non_linearity="swish",
- skip_time_act=False,
- time_embedding_norm="default", # default, scale_shift, ada_group, spatial
- kernel=None,
- output_scale_factor=1.0,
- use_in_shortcut=None,
- up=False,
- down=False,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ conv_shortcut: bool = False,
+ dropout: float = 0.0,
+ temb_channels: int = 512,
+ groups: int = 32,
+ groups_out: Optional[int] = None,
+ pre_norm: bool = True,
+ eps: float = 1e-6,
+ non_linearity: str = "swish",
+ skip_time_act: bool = False,
+ time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
+ kernel: Optional[torch.FloatTensor] = None,
+ output_scale_factor: float = 1.0,
+ use_in_shortcut: Optional[bool] = None,
+ up: bool = False,
+ down: bool = False,
conv_shortcut_bias: bool = True,
conv_2d_out_channels: Optional[int] = None,
):
@@ -667,7 +741,7 @@ def forward(self, input_tensor, temb, scale: float = 1.0):
# unet_rl.py
-def rearrange_dims(tensor):
+def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
if len(tensor.shape) == 2:
return tensor[:, :, None]
if len(tensor.shape) == 3:
@@ -681,16 +755,24 @@ def rearrange_dims(tensor):
class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
+
+ Parameters:
+ inp_channels (`int`): Number of input channels.
+ out_channels (`int`): Number of output channels.
+ kernel_size (`int` or `tuple`): Size of the convolving kernel.
+ n_groups (`int`, default `8`): Number of groups to separate the channels into.
"""
- def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
+ def __init__(
+ self, inp_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], n_groups: int = 8
+ ):
super().__init__()
self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.group_norm = nn.GroupNorm(n_groups, out_channels)
self.mish = nn.Mish()
- def forward(self, inputs):
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
intermediate_repr = self.conv1d(inputs)
intermediate_repr = rearrange_dims(intermediate_repr)
intermediate_repr = self.group_norm(intermediate_repr)
@@ -701,7 +783,19 @@ def forward(self, inputs):
# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
- def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
+ """
+ Residual 1D block with temporal convolutions.
+
+ Parameters:
+ inp_channels (`int`): Number of input channels.
+ out_channels (`int`): Number of output channels.
+ embed_dim (`int`): Embedding dimension.
+ kernel_size (`int` or `tuple`): Size of the convolving kernel.
+ """
+
+ def __init__(
+ self, inp_channels: int, out_channels: int, embed_dim: int, kernel_size: Union[int, Tuple[int, int]] = 5
+ ):
super().__init__()
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
@@ -713,7 +807,7 @@ def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
- def forward(self, inputs, t):
+ def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""
Args:
inputs : [ batch_size x inp_channels x horizon ]
@@ -729,7 +823,9 @@ def forward(self, inputs, t):
return out + self.residual_conv(inputs)
-def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
+def upsample_2d(
+ hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
+) -> torch.Tensor:
r"""Upsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
@@ -766,7 +862,9 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
return output
-def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
+def downsample_2d(
+ hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
+) -> torch.Tensor:
r"""Downsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
@@ -801,7 +899,9 @@ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
return output
-def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
+def upfirdn2d_native(
+ tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0)
+) -> torch.Tensor:
up_x = up_y = up
down_x = down_y = down
pad_x0 = pad_y0 = pad[0]
@@ -849,9 +949,14 @@ class TemporalConvLayer(nn.Module):
"""
Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
+
+ Parameters:
+ in_dim (`int`): Number of input channels.
+ out_dim (`int`): Number of output channels.
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
"""
- def __init__(self, in_dim, out_dim=None, dropout=0.0):
+ def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0):
super().__init__()
out_dim = out_dim or in_dim
self.in_dim = in_dim
@@ -884,7 +989,7 @@ def __init__(self, in_dim, out_dim=None, dropout=0.0):
nn.init.zeros_(self.conv4[-1].weight)
nn.init.zeros_(self.conv4[-1].bias)
- def forward(self, hidden_states, num_frames=1):
+ def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
hidden_states = (
hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
)
diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py
index c96aef65f339..e7780a7bca3d 100644
--- a/src/diffusers/models/transformer_2d.py
+++ b/src/diffusers/models/transformer_2d.py
@@ -235,6 +235,14 @@ def forward(
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ attention_mask ( `torch.Tensor`, *optional*):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py
index cfafdb055bcf..d002cb3315fa 100644
--- a/src/diffusers/models/transformer_temporal.py
+++ b/src/diffusers/models/transformer_temporal.py
@@ -128,6 +128,12 @@ def forward(
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
+ num_frames (`int`, *optional*, defaults to 1):
+ The number of frames to be processed per batch. This is used to reshape the hidden states.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index 8aebb3aad615..18b0e41af738 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -19,6 +19,7 @@
from torch import nn
from ..utils import is_torch_version, logging
+from ..utils.torch_utils import apply_freeu
from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
@@ -249,6 +250,7 @@ def get_up_block(
add_upsample,
resnet_eps,
resnet_act_fn,
+ resolution_idx=None,
transformer_layers_per_block=1,
num_attention_heads=None,
resnet_groups=None,
@@ -281,6 +283,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -295,6 +298,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -314,6 +318,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -337,6 +342,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -362,6 +368,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
dropout=dropout,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -377,6 +384,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -390,6 +398,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -402,6 +411,7 @@ def get_up_block(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
+ resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -415,6 +425,7 @@ def get_up_block(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
+ resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -430,6 +441,7 @@ def get_up_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -441,6 +453,7 @@ def get_up_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -453,6 +466,18 @@ def get_up_block(
class AutoencoderTinyBlock(nn.Module):
+ """
+ Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU blocks.
+
+ Args:
+ in_channels (`int`): The number of input channels.
+ out_channels (`int`): The number of output channels.
+ act_fn (`str`):` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
+
+ Returns:
+ `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to `out_channels`.
+ """
+
def __init__(self, in_channels: int, out_channels: int, act_fn: str):
super().__init__()
act_fn = get_activation(act_fn)
@@ -1993,6 +2018,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2075,6 +2101,8 @@ def __init__(
else:
self.upsamplers = None
+ self.resolution_idx = resolution_idx
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
@@ -2103,6 +2131,7 @@ def __init__(
out_channels: int,
prev_output_channel: int,
temb_channels: int,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
@@ -2181,6 +2210,7 @@ def __init__(
self.upsamplers = None
self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
def forward(
self,
@@ -2194,11 +2224,30 @@ def forward(
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
@@ -2252,6 +2301,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2292,12 +2342,33 @@ def __init__(
self.upsamplers = None
self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
@@ -2331,6 +2402,7 @@ def __init__(
self,
in_channels: int,
out_channels: int,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2370,6 +2442,8 @@ def __init__(
else:
self.upsamplers = None
+ self.resolution_idx = resolution_idx
+
def forward(self, hidden_states, temb=None, scale: float = 1.0):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
@@ -2386,6 +2460,7 @@ def __init__(
self,
in_channels: int,
out_channels: int,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2449,6 +2524,8 @@ def __init__(
else:
self.upsamplers = None
+ self.resolution_idx = resolution_idx
+
def forward(self, hidden_states, temb=None, scale: float = 1.0):
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
@@ -2469,6 +2546,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2553,6 +2631,8 @@ def __init__(
self.skip_norm = None
self.act = None
+ self.resolution_idx = resolution_idx
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0):
for resnet in self.resnets:
# pop res hidden states
@@ -2589,6 +2669,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2651,6 +2732,8 @@ def __init__(
self.skip_norm = None
self.act = None
+ self.resolution_idx = resolution_idx
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0):
for resnet in self.resnets:
# pop res hidden states
@@ -2684,6 +2767,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2743,6 +2827,7 @@ def __init__(
self.upsamplers = None
self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
for resnet in self.resnets:
@@ -2784,6 +2869,7 @@ def __init__(
out_channels: int,
prev_output_channel: int,
temb_channels: int,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2873,6 +2959,7 @@ def __init__(
self.upsamplers = None
self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
def forward(
self,
@@ -2947,6 +3034,7 @@ def __init__(
in_channels: int,
out_channels: int,
temb_channels: int,
+ resolution_idx: int,
dropout: float = 0.0,
num_layers: int = 5,
resnet_eps: float = 1e-5,
@@ -2988,6 +3076,7 @@ def __init__(
self.upsamplers = None
self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
res_hidden_states_tuple = res_hidden_states_tuple[-1]
@@ -3027,6 +3116,7 @@ def __init__(
in_channels: int,
out_channels: int,
temb_channels: int,
+ resolution_idx: int,
dropout: float = 0.0,
num_layers: int = 4,
resnet_eps: float = 1e-5,
@@ -3104,6 +3194,7 @@ def __init__(
self.upsamplers = None
self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
def forward(
self,
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 866254a89545..f858a7685360 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -542,6 +542,7 @@ def __init__(
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
+ resolution_idx=i,
resnet_groups=norm_num_groups,
cross_attention_dim=reversed_cross_attention_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
@@ -733,6 +734,38 @@ def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
+ def enable_freeu(self, s1, s2, b1, b2):
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
+ def disable_freeu(self):
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None:
+ setattr(upsample_block, k, None)
+
def forward(
self,
sample: torch.FloatTensor,
@@ -757,6 +790,26 @@ def forward(
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`):
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
+ A tensor that if specified is added to the residual of the middle unet block.
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py
index a3aebde7bf16..770cbf09ccac 100644
--- a/src/diffusers/models/unet_2d_condition_flax.py
+++ b/src/diffusers/models/unet_2d_condition_flax.py
@@ -126,7 +126,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
addition_embed_type_num_heads: int = 64
projection_class_embeddings_input_dim: Optional[int] = None
- def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
+ def init_weights(self, rng: jax.Array) -> FrozenDict:
# init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
@@ -334,6 +334,13 @@ def __call__(
sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
timestep (`jnp.ndarray` or `float` or `int`): timesteps
encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
+ A tensor that if specified is added to the residual of the middle unet block.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
plain tuple.
diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py
index ab5c393518e2..180ae0dc1a81 100644
--- a/src/diffusers/models/unet_3d_blocks.py
+++ b/src/diffusers/models/unet_3d_blocks.py
@@ -15,6 +15,7 @@
import torch
from torch import nn
+from ..utils.torch_utils import apply_freeu
from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
from .transformer_2d import Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
@@ -87,6 +88,7 @@ def get_up_block(
resnet_eps,
resnet_act_fn,
num_attention_heads,
+ resolution_idx=None,
resnet_groups=None,
cross_attention_dim=None,
dual_cross_attention=False,
@@ -107,6 +109,7 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
+ resolution_idx=resolution_idx,
)
elif up_block_type == "CrossAttnUpBlock3D":
if cross_attention_dim is None:
@@ -128,6 +131,7 @@ def get_up_block(
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
+ resolution_idx=resolution_idx,
)
raise ValueError(f"{up_block_type} does not exist.")
@@ -496,6 +500,7 @@ def __init__(
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
+ resolution_idx=None,
):
super().__init__()
resnets = []
@@ -565,6 +570,7 @@ def __init__(
self.upsamplers = None
self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
def forward(
self,
@@ -577,6 +583,13 @@ def forward(
num_frames=1,
cross_attention_kwargs=None,
):
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+
# TODO(Patrick, William) - attention mask is not used
for resnet, temp_conv, attn, temp_attn in zip(
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
@@ -584,6 +597,19 @@ def forward(
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
@@ -621,6 +647,7 @@ def __init__(
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_upsample=True,
+ resolution_idx=None,
):
super().__init__()
resnets = []
@@ -661,12 +688,32 @@ def __init__(
self.upsamplers = None
self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py
index 01af31061d10..2ab1d4060e17 100644
--- a/src/diffusers/models/unet_3d_condition.py
+++ b/src/diffusers/models/unet_3d_condition.py
@@ -255,6 +255,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
num_attention_heads=reversed_num_attention_heads[i],
dual_cross_attention=False,
+ resolution_idx=i,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -462,6 +463,40 @@ def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
module.gradient_checkpointing = value
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
+ def enable_freeu(self, s1, s2, b1, b2):
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None:
+ setattr(upsample_block, k, None)
+
def forward(
self,
sample: torch.FloatTensor,
@@ -484,6 +519,23 @@ def forward(
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`):
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
+ A tensor that if specified is added to the residual of the middle unet block.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
tuple.
diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py
index b8f5b1d0e399..d2dde2ba197b 100644
--- a/src/diffusers/models/vae_flax.py
+++ b/src/diffusers/models/vae_flax.py
@@ -817,7 +817,7 @@ def setup(self):
dtype=self.dtype,
)
- def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
+ def init_weights(self, rng: jax.Array) -> FrozenDict:
# init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
index 87eb52eef3c0..ba3930f5da59 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -547,6 +547,32 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
latents = latents * self.scheduler.init_noise_sigma
return latents
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Alt Diffusion v1, v2, and Alt Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -736,7 +762,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
index 562bb5f59c56..47fa019647d4 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -765,7 +765,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/audioldm/pipeline_audioldm.py b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
index 31e09b728531..3345fb6e7586 100644
--- a/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
+++ b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
@@ -542,7 +542,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# 8. Post-processing
mel_spectrogram = self.decode_latents(latents)
diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
index 31b9266060b0..b2dd9f7bb03e 100644
--- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
@@ -945,7 +945,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
self.maybe_free_model_hooks()
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index 6ffaac6800b4..ad0060976440 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -1005,7 +1005,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
index e10a8624f068..58f003960e99 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
@@ -213,7 +213,7 @@ def prepare_control_image(
do_center_crop=False,
do_normalize=False,
return_tensors="pt",
- )["pixel_values"].to(self.device)
+ )["pixel_values"].to(device)
image_batch_size = image.shape[0]
if image_batch_size == 1:
@@ -365,7 +365,7 @@ def __call__(
height=height,
batch_size=batch_size,
num_images_per_prompt=1,
- device=self.device,
+ device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
index 9c8d6e753693..ef34ad3ee70a 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -1087,7 +1087,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index bca8bebcd60a..640ca0a22e9c 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -1351,7 +1351,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index aa998d7e5f4c..41b0d5434386 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -1507,7 +1507,8 @@ def denoising_value_valid(dnv):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 708cf869b9b9..7f230c2ec058 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -1158,7 +1158,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# manually for max memory savings
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index f8aa4a9e26c7..aeffc219674d 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -1344,7 +1344,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
diff --git a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
index b2c8871aa0d6..e1f508dc1e36 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
@@ -238,14 +238,14 @@ def _run_safety_checker(self, images, safety_model_params, jit=False):
def _generate(
self,
- prompt_ids: jnp.array,
- image: jnp.array,
+ prompt_ids: jnp.ndarray,
+ image: jnp.ndarray,
params: Union[Dict, FrozenDict],
- prng_seed: jax.random.KeyArray,
+ prng_seed: jax.Array,
num_inference_steps: int,
guidance_scale: float,
- latents: Optional[jnp.array] = None,
- neg_prompt_ids: Optional[jnp.array] = None,
+ latents: Optional[jnp.ndarray] = None,
+ neg_prompt_ids: Optional[jnp.ndarray] = None,
controlnet_conditioning_scale: float = 1.0,
):
height, width = image.shape[-2:]
@@ -348,15 +348,15 @@ def loop_body(step, args):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
- prompt_ids: jnp.array,
- image: jnp.array,
+ prompt_ids: jnp.ndarray,
+ image: jnp.ndarray,
params: Union[Dict, FrozenDict],
- prng_seed: jax.random.KeyArray,
+ prng_seed: jax.Array,
num_inference_steps: int = 50,
- guidance_scale: Union[float, jnp.array] = 7.5,
- latents: jnp.array = None,
- neg_prompt_ids: jnp.array = None,
- controlnet_conditioning_scale: Union[float, jnp.array] = 1.0,
+ guidance_scale: Union[float, jnp.ndarray] = 7.5,
+ latents: jnp.ndarray = None,
+ neg_prompt_ids: jnp.ndarray = None,
+ controlnet_conditioning_scale: Union[float, jnp.ndarray] = 1.0,
return_dict: bool = True,
jit: bool = False,
):
@@ -364,13 +364,13 @@ def __call__(
The call function to the pipeline for generation.
Args:
- prompt_ids (`jnp.array`):
+ prompt_ids (`jnp.ndarray`):
The prompt or prompts to guide the image generation.
- image (`jnp.array`):
+ image (`jnp.ndarray`):
Array representing the ControlNet input condition to provide guidance to the `unet` for generation.
params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights.
- prng_seed (`jax.random.KeyArray` or `jax.Array`):
+ prng_seed (`jax.Array`):
Array containing random number generator key.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -378,11 +378,11 @@ def __call__(
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
- latents (`jnp.array`, *optional*):
+ latents (`jnp.ndarray`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
array is generated by sampling using the supplied random `generator`.
- controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0):
+ controlnet_conditioning_scale (`float` or `jnp.ndarray`, *optional*, defaults to 1.0):
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original `unet`.
return_dict (`bool`, *optional*, defaults to `True`):
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
index a715eb784617..5c78b0dce87e 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
@@ -382,7 +382,8 @@ def __call__(
).prev_sample
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
index 3847c2eac793..a22823aadef4 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
@@ -475,7 +475,8 @@ def __call__(
).prev_sample
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# 7. post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
index 2df62829a960..144e3ce585af 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
@@ -610,7 +610,8 @@ def __call__(
).prev_sample
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
index 5d1cbb1af291..3d7b09471969 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
@@ -244,7 +244,8 @@ def __call__(
)[0]
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
index cb0465c11ef9..b6e02485bef1 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
@@ -295,7 +295,8 @@ def __call__(
)[0]
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
index 5d92d485e4b1..854b87d72f25 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
@@ -355,7 +355,8 @@ def __call__(
)[0]
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
index 160a41cbfd7a..8cf3735672a8 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
@@ -319,7 +319,8 @@ def __call__(
)[0]
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
index 68f965423926..7a9326b708e5 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
@@ -471,7 +471,8 @@ def __call__(
latents = init_mask * init_latents_proper + (1 - init_mask) * latents
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# post-processing
latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents
diff --git a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
index 4ee07f4e056a..9e6b6fea13e5 100644
--- a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
+++ b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
@@ -616,7 +616,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
self.maybe_free_model_hooks()
diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
index 01796eb43079..a782caa55efc 100644
--- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
+++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -581,7 +581,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
self.maybe_free_model_hooks()
diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py
index a9aa0d35fdac..7b067405cace 100644
--- a/src/diffusers/pipelines/pipeline_flax_utils.py
+++ b/src/diffusers/pipelines/pipeline_flax_utils.py
@@ -341,8 +341,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
- # make sure we don't download PyTorch weights, unless when using from_pt
- ignore_patterns = "*.bin" if not from_pt else []
+ ignore_patterns = ["*.bin", "*.safetensors"] if not from_pt else []
+ ignore_patterns += ["*.onnx", "*.onnx_data", "*.xml", "*.pb"]
if cls != FlaxDiffusionPipeline:
requested_pipeline_class = cls.__name__
diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
index 8ceb173f3ee2..c467d5ebe829 100644
--- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
+++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
@@ -689,7 +689,8 @@ def __call__(
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# 8. Post-processing
if not output_type == "latent":
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 618ee1942224..e97f66bbcb24 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -304,8 +304,6 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
class_embed_type = "projection"
assert "adm_in_channels" in unet_params
projection_class_embeddings_input_dim = unet_params.adm_in_channels
- else:
- raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
config = {
"sample_size": image_size // vae_scale_factor,
@@ -323,6 +321,12 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
"transformer_layers_per_block": transformer_layers_per_block,
}
+ if "disable_self_attentions" in unet_params:
+ config["only_cross_attention"] = unet_params.disable_self_attentions
+
+ if "num_classes" in unet_params and type(unet_params.num_classes) == int:
+ config["num_class_embeds"] = unet_params.num_classes
+
if controlnet:
config["conditioning_channels"] = unet_params.hint_channels
else:
@@ -441,6 +445,10 @@ def convert_ldm_unet_checkpoint(
new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+ # Relevant to StableDiffusionUpscalePipeline
+ if "num_class_embeds" in config:
+ new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
@@ -496,6 +504,7 @@ def convert_ldm_unet_checkpoint(
if len(attentions):
paths = renew_attention_paths(attentions)
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
@@ -1210,6 +1219,7 @@ def download_from_original_stable_diffusion_ckpt(
StableDiffusionControlNetPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
+ StableDiffusionUpscalePipeline,
StableDiffusionXLImg2ImgPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
@@ -1256,6 +1266,8 @@ def download_from_original_stable_diffusion_ckpt(
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
+ is_upscale = pipeline_class == StableDiffusionUpscalePipeline
+
config_url = None
# model_type = "v1"
@@ -1285,6 +1297,10 @@ def download_from_original_stable_diffusion_ckpt(
original_config_file = config_files["xl_refiner"]
else:
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
+
+ if is_upscale:
+ config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
+
if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content)
@@ -1308,6 +1324,8 @@ def download_from_original_stable_diffusion_ckpt(
if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
num_in_channels = 9
+ if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
+ num_in_channels = 7
elif num_in_channels is None:
num_in_channels = 4
@@ -1391,9 +1409,13 @@ def download_from_original_stable_diffusion_ckpt(
else:
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
+ if pipeline_class == StableDiffusionUpscalePipeline:
+ image_size = original_config.model.params.unet_config.params.image_size
+
# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["upcast_attention"] = upcast_attention
+
path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, unet_config, path=path, extract_ema=extract_ema
@@ -1458,8 +1480,29 @@ def download_from_original_stable_diffusion_ckpt(
controlnet=controlnet,
safety_checker=None,
feature_extractor=None,
- requires_safety_checker=False,
)
+ if hasattr(pipe, "requires_safety_checker"):
+ pipe.requires_safety_checker = False
+
+ elif pipeline_class == StableDiffusionUpscalePipeline:
+ scheduler = DDIMScheduler.from_pretrained(
+ "stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler"
+ )
+ low_res_scheduler = DDPMScheduler.from_pretrained(
+ "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler"
+ )
+
+ pipe = pipeline_class(
+ vae=vae,
+ text_encoder=text_model,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ low_res_scheduler=low_res_scheduler,
+ safety_checker=None,
+ feature_extractor=None,
+ )
+
else:
pipe = pipeline_class(
vae=vae,
@@ -1469,8 +1512,10 @@ def download_from_original_stable_diffusion_ckpt(
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
- requires_safety_checker=False,
)
+ if hasattr(pipe, "requires_safety_checker"):
+ pipe.requires_safety_checker = False
+
else:
image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
original_config, clip_stats_path=clip_stats_path, device=device
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
index 1752729e0992..6bcbbab135df 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
@@ -890,7 +890,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# 9. Post-processing
if not output_type == "latent":
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index 131a7c7bc2bd..bcf2a6217772 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -215,13 +215,13 @@ def _generate(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
- prng_seed: jax.random.KeyArray,
+ prng_seed: jax.Array,
num_inference_steps: int,
height: int,
width: int,
guidance_scale: float,
- latents: Optional[jnp.array] = None,
- neg_prompt_ids: Optional[jnp.array] = None,
+ latents: Optional[jnp.ndarray] = None,
+ neg_prompt_ids: Optional[jnp.ndarray] = None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -312,13 +312,13 @@ def __call__(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
- prng_seed: jax.random.KeyArray,
+ prng_seed: jax.Array,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
- guidance_scale: Union[float, jnp.array] = 7.5,
- latents: jnp.array = None,
- neg_prompt_ids: jnp.array = None,
+ guidance_scale: Union[float, jnp.ndarray] = 7.5,
+ latents: jnp.ndarray = None,
+ neg_prompt_ids: jnp.ndarray = None,
return_dict: bool = True,
jit: bool = False,
):
@@ -338,7 +338,7 @@ def __call__(
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
- latents (`jnp.array`, *optional*):
+ latents (`jnp.ndarray`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
array is generated by sampling using the supplied random `generator`.
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
index a9717533fa93..c1fd310ea582 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
@@ -232,17 +232,17 @@ def get_timestep_start(self, num_inference_steps, strength):
def _generate(
self,
- prompt_ids: jnp.array,
- image: jnp.array,
+ prompt_ids: jnp.ndarray,
+ image: jnp.ndarray,
params: Union[Dict, FrozenDict],
- prng_seed: jax.random.KeyArray,
+ prng_seed: jax.Array,
start_timestep: int,
num_inference_steps: int,
height: int,
width: int,
guidance_scale: float,
- noise: Optional[jnp.array] = None,
- neg_prompt_ids: Optional[jnp.array] = None,
+ noise: Optional[jnp.ndarray] = None,
+ neg_prompt_ids: Optional[jnp.ndarray] = None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -337,17 +337,17 @@ def loop_body(step, args):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
- prompt_ids: jnp.array,
- image: jnp.array,
+ prompt_ids: jnp.ndarray,
+ image: jnp.ndarray,
params: Union[Dict, FrozenDict],
- prng_seed: jax.random.KeyArray,
+ prng_seed: jax.Array,
strength: float = 0.8,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
- guidance_scale: Union[float, jnp.array] = 7.5,
- noise: jnp.array = None,
- neg_prompt_ids: jnp.array = None,
+ guidance_scale: Union[float, jnp.ndarray] = 7.5,
+ noise: jnp.ndarray = None,
+ neg_prompt_ids: jnp.ndarray = None,
return_dict: bool = True,
jit: bool = False,
):
@@ -355,13 +355,13 @@ def __call__(
The call function to the pipeline for generation.
Args:
- prompt_ids (`jnp.array`):
+ prompt_ids (`jnp.ndarray`):
The prompt or prompts to guide image generation.
- image (`jnp.array`):
+ image (`jnp.ndarray`):
Array representing an image batch to be used as the starting point.
params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights.
- prng_seed (`jax.random.KeyArray` or `jax.Array`):
+ prng_seed (`jax.Array` or `jax.Array`):
Array containing random number generator key.
strength (`float`, *optional*, defaults to 0.8):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
@@ -379,7 +379,7 @@ def __call__(
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
- noise (`jnp.array`, *optional*):
+ noise (`jnp.ndarray`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. The array is generated by
sampling using the supplied random `generator`.
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
index b43fa3837062..b9a2331a061c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
@@ -266,17 +266,17 @@ def _run_safety_checker(self, images, safety_model_params, jit=False):
def _generate(
self,
- prompt_ids: jnp.array,
- mask: jnp.array,
- masked_image: jnp.array,
+ prompt_ids: jnp.ndarray,
+ mask: jnp.ndarray,
+ masked_image: jnp.ndarray,
params: Union[Dict, FrozenDict],
- prng_seed: jax.random.KeyArray,
+ prng_seed: jax.Array,
num_inference_steps: int,
height: int,
width: int,
guidance_scale: float,
- latents: Optional[jnp.array] = None,
- neg_prompt_ids: Optional[jnp.array] = None,
+ latents: Optional[jnp.ndarray] = None,
+ neg_prompt_ids: Optional[jnp.ndarray] = None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -394,17 +394,17 @@ def loop_body(step, args):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
- prompt_ids: jnp.array,
- mask: jnp.array,
- masked_image: jnp.array,
+ prompt_ids: jnp.ndarray,
+ mask: jnp.ndarray,
+ masked_image: jnp.ndarray,
params: Union[Dict, FrozenDict],
- prng_seed: jax.random.KeyArray,
+ prng_seed: jax.Array,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
- guidance_scale: Union[float, jnp.array] = 7.5,
- latents: jnp.array = None,
- neg_prompt_ids: jnp.array = None,
+ guidance_scale: Union[float, jnp.ndarray] = 7.5,
+ latents: jnp.ndarray = None,
+ neg_prompt_ids: jnp.ndarray = None,
return_dict: bool = True,
jit: bool = False,
):
@@ -424,7 +424,7 @@ def __call__(
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
- latents (`jnp.array`, *optional*):
+ latents (`jnp.ndarray`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
array is generated by sampling using the supplied random `generator`.
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
index 6c8ff7fe78df..87640afbbc89 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
@@ -423,7 +423,8 @@ def __call__(
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
index b20b8ebb98e3..055d9b02c15d 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
@@ -513,7 +513,8 @@ def __call__(
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
index a8d4d2dc6019..88d300c10b55 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -524,7 +524,8 @@ def __call__(
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py
index 3b9d83ee0a25..fece365af49b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py
@@ -503,7 +503,8 @@ def __call__(
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
index 4c853f7e63ac..dec4134d4326 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
@@ -555,7 +555,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# 10. Post-processing
image = self.decode_latents(latents)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 70095a448e32..68cdbbe78b5a 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -537,6 +537,32 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
latents = latents * self.scheduler.init_noise_sigma
return latents
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -726,7 +752,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
index 7f09545bde88..e49e12b92ea3 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
@@ -1003,7 +1003,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# 8. Post-processing
if not output_type == "latent":
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
index e5b334914f02..95c3a79cf0c5 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -757,7 +757,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py
index 3e328da0939c..7126b798feb5 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py
@@ -1287,7 +1287,8 @@ def invert(
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
assert len(inverted_latents) == len(timesteps)
latents = torch.stack(list(reversed(inverted_latents)), 1)
@@ -1531,7 +1532,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py
index 79dadb6fb568..f176f08d5d8c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py
@@ -803,7 +803,9 @@ def __call__(
if gligen_inpaint_image is not None:
gligen_inpaint_latent_with_noise = (
- self.scheduler.add_noise(gligen_inpaint_latent, torch.randn_like(gligen_inpaint_latent), t)
+ self.scheduler.add_noise(
+ gligen_inpaint_latent, torch.randn_like(gligen_inpaint_latent), torch.tensor([t])
+ )
.expand(latents.shape[0], -1, -1, -1)
.clone()
)
@@ -838,7 +840,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py
index fd8fe4775386..ba418b4cb3c3 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py
@@ -965,7 +965,9 @@ def __call__(
if gligen_inpaint_image is not None:
gligen_inpaint_latent_with_noise = (
- self.scheduler.add_noise(gligen_inpaint_latent, torch.randn_like(gligen_inpaint_latent), t)
+ self.scheduler.add_noise(
+ gligen_inpaint_latent, torch.randn_like(gligen_inpaint_latent), torch.tensor([t])
+ )
.expand(latents.shape[0], -1, -1, -1)
.clone()
)
@@ -1012,7 +1014,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
index 34b4efb4e210..133311ed849c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
@@ -392,7 +392,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
self.maybe_free_model_hooks()
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 2463a99f6ec5..8c180f5224b7 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -760,7 +760,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 23f6935d8f8d..e792eb8f8c12 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -1045,7 +1045,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
condition_kwargs = {}
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
index 6dd7db93b9fc..4b555e0367c6 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -744,7 +744,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# use original latents corresponding to unmasked portions of the image
latents = (init_latents_orig * mask) + (latents * (1 - mask))
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
index f31eb197ad3f..000a9012d05c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -378,7 +378,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
index b49de9d3e3ca..1e38142b9c66 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
@@ -22,6 +22,7 @@
from transformers import CLIPTextModel, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import deprecate, logging
@@ -59,7 +60,7 @@ def preprocess(image):
return image
-class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
+class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, FromSingleFileMixin):
r"""
Pipeline for upscaling Stable Diffusion output image resolution by a factor of 2.
@@ -472,7 +473,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
index 343fc7e5e12f..eb3ba4b90a71 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
@@ -674,7 +674,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
index 8e086541a1ad..e67c04ebcf7c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
@@ -802,7 +802,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
index f544020ce012..1704e28f0c7f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
@@ -770,7 +770,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
if circular_padding:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
index 2250dfc93b72..6cbea1d1da7e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
@@ -1006,7 +1006,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# 8. Compute the edit directions.
edit_direction = self.construct_direction(source_embeds, target_embeds).to(prompt_embeds.device)
@@ -1283,7 +1284,8 @@ def invert(
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
inverted_latents = latents.detach().clone()
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
index dc59faeabdc3..42cc9905c49a 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
@@ -712,7 +712,8 @@ def get_map_size(module, input, output):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index 2c637e5142a4..8d01e0a0d086 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -22,7 +22,7 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
@@ -67,7 +67,9 @@ def preprocess(image):
return image
-class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionUpscalePipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
r"""
Pipeline for text-guided image super-resolution using Stable Diffusion 2.
@@ -756,7 +758,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
@@ -764,8 +767,9 @@ def __call__(
if needs_upcasting:
self.upcast_vae()
- latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ # Ensure latents are always the same type as the VAE
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
# cast back to fp16 if needed
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
index 7bea2411c698..3b12058eda7b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
@@ -925,7 +925,8 @@ def __call__(
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
index 7710105b46d7..3ef1994b0cb3 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
@@ -821,7 +821,8 @@ def __call__(
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# 9. Post-processing
if not output_type == "latent":
diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
index 3a8c31679540..5966600462bf 100644
--- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
+++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
@@ -87,7 +87,7 @@ def __init__(
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
- def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
+ def init_weights(self, rng: jax.Array, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensor
clip_input = jax.random.normal(rng, input_shape)
diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
index 40326c1c035b..12f4551d9de3 100644
--- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
+++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
@@ -674,7 +674,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# 8. Post-processing
image = self.decode_latents(latents)
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
index 3acb5ae538a4..8f043c7c6657 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
@@ -89,7 +89,7 @@ def __call__(
self,
prompt_ids: jax.Array,
params: Union[Dict, FrozenDict],
- prng_seed: jax.random.KeyArray,
+ prng_seed: jax.Array,
num_inference_steps: int = 50,
guidance_scale: Union[float, jax.Array] = 7.5,
height: Optional[int] = None,
@@ -170,7 +170,7 @@ def _generate(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
- prng_seed: jax.random.KeyArray,
+ prng_seed: jax.Array,
num_inference_steps: int,
height: int,
width: int,
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index c6584c1a5b40..4c1bd857d7cb 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -560,6 +560,34 @@ def upcast_vae(self):
self.vae.decoder.conv_in.to(dtype)
self.vae.decoder.mid_block.to(dtype)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+ def enable_freeu(self, s1=0.9, s2=0.2, b1=1.2, b2=1.4):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -877,7 +905,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 941318abc518..9612a8e28f8e 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -1028,7 +1028,8 @@ def denoising_value_valid(dnv):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 854c51ea2225..209c9b339aec 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -1352,7 +1352,8 @@ def denoising_value_valid(dnv):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
index fd62d6e60942..6fd1be88b284 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
@@ -923,7 +923,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
index 7c1020792fea..2ab3bf00c8fc 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -799,7 +799,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if output_type == "latent":
image = latents
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
index 8d50483f2e76..b32c852481ab 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -970,7 +970,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
index 0445d600199f..42c00597beee 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
@@ -472,6 +472,34 @@ def prepare_latents(
latents = latents * self.scheduler.init_noise_sigma
return latents
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -664,7 +692,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if output_type == "latent":
return TextToVideoSDPipelineOutput(frames=latents)
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
index b6c35363de23..c571d3d6bc5e 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
@@ -736,7 +736,8 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if output_type == "latent":
return TextToVideoSDPipelineOutput(frames=latents)
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
index 05e821543a38..277726781eee 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
@@ -414,7 +414,8 @@ def backward_loop(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
return latents.clone().detach()
@torch.no_grad()
diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
index 6e9c9f96b0ce..7e0b07cc79ef 100644
--- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
+++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
@@ -13,9 +13,12 @@
GPT2Tokenizer,
)
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL
+from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, is_accelerate_version, logging
+from ...utils import deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.outputs import BaseOutput
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -26,30 +29,6 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
-def preprocess(image):
- deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
- deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
- if isinstance(image, torch.Tensor):
- return image
- elif isinstance(image, PIL.Image.Image):
- image = [image]
-
- if isinstance(image[0], PIL.Image.Image):
- w, h = image[0].size
- w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
-
- image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
- image = np.concatenate(image, axis=0)
- image = np.array(image).astype(np.float32) / 255.0
- image = image.transpose(0, 3, 1, 2)
- image = 2.0 * image - 1.0
- image = torch.from_numpy(image)
- elif isinstance(image[0], torch.Tensor):
- image = torch.cat(image, dim=0)
- return image
-
-
# New BaseOutput child class for joint image-text output
@dataclass
class ImageTextPipelineOutput(BaseOutput):
@@ -111,7 +90,7 @@ def __init__(
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
image_encoder: CLIPVisionModelWithProjection,
- image_processor: CLIPImageProcessor,
+ clip_image_processor: CLIPImageProcessor,
clip_tokenizer: CLIPTokenizer,
text_decoder: UniDiffuserTextDecoder,
text_tokenizer: GPT2Tokenizer,
@@ -130,7 +109,7 @@ def __init__(
vae=vae,
text_encoder=text_encoder,
image_encoder=image_encoder,
- image_processor=image_processor,
+ clip_image_processor=clip_image_processor,
clip_tokenizer=clip_tokenizer,
text_decoder=text_decoder,
text_tokenizer=text_tokenizer,
@@ -139,6 +118,7 @@ def __init__(
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.num_channels_latents = vae.config.latent_channels
self.text_encoder_seq_len = text_encoder.config.max_position_embeddings
@@ -155,43 +135,38 @@ def __init__(
# TODO: handle safety checking?
self.safety_checker = None
- # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
- # Add self.image_encoder, self.text_decoder to cpu_offloaded_models list
- def enable_model_cpu_offload(self, gpu_id=0):
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
r"""
- Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
- to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
- method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
- `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
- if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
- from accelerate import cpu_offload_with_hook
- else:
- raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
-
- device = torch.device(f"cuda:{gpu_id}")
+ self.vae.enable_slicing()
- if self.device.type != "cpu":
- self.to("cpu", silence_dtype_warnings=True)
- torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
-
- hook = None
- for cpu_offloaded_model in [
- self.text_encoder.text_model,
- self.image_encoder,
- self.unet,
- self.vae,
- self.text_decoder.encode_prefix,
- self.text_decoder.decode_prefix,
- self.text_decoder,
- ]:
- _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
- if self.safety_checker is not None:
- _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
- # We'll offload the last model manually.
- self.final_offload_hook = hook
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
@@ -370,8 +345,7 @@ def _infer_batch_size(
)
return batch_size, multiplier
- # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
- # self.tokenizer => self.clip_tokenizer
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
self,
prompt,
@@ -381,6 +355,41 @@ def _encode_prompt(
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ **kwargs,
+ ):
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ **kwargs,
+ )
+
+ # concatenate for backwards comp
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with self.tokenizer->self.clip_tokenizer
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -396,8 +405,8 @@ def _encode_prompt(
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
- `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
- Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -405,7 +414,23 @@ def _encode_prompt(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
"""
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not self.use_peft_backend:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -414,6 +439,10 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
+ # textual inversion: procecss multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.clip_tokenizer)
+
text_inputs = self.clip_tokenizer(
prompt,
padding="max_length",
@@ -440,13 +469,31 @@ def _encode_prompt(
else:
attention_mask = None
- prompt_embeds = self.text_encoder(
- text_input_ids.to(device),
- attention_mask=attention_mask,
- )
- prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -458,7 +505,7 @@ def _encode_prompt(
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
- elif type(prompt) is not type(negative_prompt):
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
@@ -474,6 +521,10 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
+ # textual inversion: procecss multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.clip_tokenizer)
+
max_length = prompt_embeds.shape[1]
uncond_input = self.clip_tokenizer(
uncond_tokens,
@@ -498,17 +549,16 @@ def _encode_prompt(
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
- # For classifier free guidance, we need to do two forward passes.
- # Here we concatenate the unconditional and text embeddings into a single batch
- # to avoid doing two forward passes
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder)
- return prompt_embeds
+ return prompt_embeds, negative_prompt_embeds
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.prepare_image_latents
# Add num_prompts_per_image argument, sample from autoencoder moment distribution
@@ -587,7 +637,7 @@ def encode_image_clip_latents(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
- preprocessed_image = self.image_processor.preprocess(
+ preprocessed_image = self.clip_image_processor.preprocess(
image,
return_tensors="pt",
)
@@ -628,17 +678,6 @@ def encode_image_clip_latents(
return image_latents
- # Note that the CLIP latents are not decoded for image generation.
- # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
- # Rename: decode_latents -> decode_image_latents
- def decode_image_latents(self, latents):
- latents = 1 / self.vae.config.scaling_factor * latents
- image = self.vae.decode(latents, return_dict=False)[0]
- image = (image / 2 + 0.5).clamp(0, 1)
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
- return image
-
def prepare_text_latents(
self, batch_size, num_images_per_prompt, seq_len, hidden_size, dtype, device, generator, latents=None
):
@@ -720,6 +759,17 @@ def prepare_image_clip_latents(
latents = latents * self.scheduler.init_noise_sigma
return latents
+ def decode_text_latents(self, text_latents, device):
+ output_token_list, seq_lengths = self.text_decoder.generate_captions(
+ text_latents, self.text_tokenizer.eos_token_id, device=device
+ )
+ output_list = output_token_list.cpu().numpy()
+ generated_text = [
+ self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True)
+ for output, length in zip(output_list, seq_lengths)
+ ]
+ return generated_text
+
def _split(self, x, height, width):
r"""
Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W)
@@ -1181,7 +1231,7 @@ def __call__(
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
# Note that this differs from the formulation in the unidiffusers paper!
- # do_classifier_free_guidance = guidance_scale > 1.0
+ do_classifier_free_guidance = guidance_scale > 1.0
# check if scheduler is in sigmas space
# scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
@@ -1194,15 +1244,18 @@ def __call__(
if mode in ["text2img"]:
# 3.1. Encode input prompt, if available
assert prompt is not None or prompt_embeds is not None
- prompt_embeds = self._encode_prompt(
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt=prompt,
device=device,
num_images_per_prompt=multiplier,
- do_classifier_free_guidance=False, # don't support standard classifier-free guidance for now
+ do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
+
+ # if do_classifier_free_guidance:
+ # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
else:
# 3.2. Prepare text latent variables, if input not available
prompt_embeds = self.prepare_text_latents(
@@ -1224,7 +1277,7 @@ def __call__(
# 4.1. Encode images, if available
assert image is not None, "`img2text` requires a conditioning image"
# Encode image using VAE
- image_vae = preprocess(image)
+ image_vae = self.image_processor.preprocess(image)
height, width = image_vae.shape[-2:]
image_vae_latents = self.encode_image_vae_latents(
image=image_vae,
@@ -1321,51 +1374,46 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# 9. Post-processing
- gen_image = None
- gen_text = None
+ image = None
+ text = None
if mode == "joint":
image_vae_latents, image_clip_latents, text_latents = self._split_joint(latents, height, width)
- # Map latent VAE image back to pixel space
- gen_image = self.decode_image_latents(image_vae_latents)
+ if not output_type == "latent":
+ # Map latent VAE image back to pixel space
+ image = self.vae.decode(image_vae_latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = image_vae_latents
- # Generate text using the text decoder
- output_token_list, seq_lengths = self.text_decoder.generate_captions(
- text_latents, self.text_tokenizer.eos_token_id, device=device
- )
- output_list = output_token_list.cpu().numpy()
- gen_text = [
- self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True)
- for output, length in zip(output_list, seq_lengths)
- ]
+ text = self.decode_text_latents(text_latents, device)
elif mode in ["text2img", "img"]:
image_vae_latents, image_clip_latents = self._split(latents, height, width)
- gen_image = self.decode_image_latents(image_vae_latents)
+
+ if not output_type == "latent":
+ # Map latent VAE image back to pixel space
+ image = self.vae.decode(image_vae_latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = image_vae_latents
elif mode in ["img2text", "text"]:
text_latents = latents
- output_token_list, seq_lengths = self.text_decoder.generate_captions(
- text_latents, self.text_tokenizer.eos_token_id, device=device
- )
- output_list = output_token_list.cpu().numpy()
- gen_text = [
- self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True)
- for output, length in zip(output_list, seq_lengths)
- ]
+ text = self.decode_text_latents(text_latents, device)
self.maybe_free_model_hooks()
- # 10. Convert to PIL
- if output_type == "pil" and gen_image is not None:
- gen_image = self.numpy_to_pil(gen_image)
+ # 10. Postprocess the image, if necessary
+ if image is not None:
+ do_denormalize = [True] * image.shape[0]
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
- return (gen_image, gen_text)
+ return (image, text)
- return ImageTextPipelineOutput(images=gen_image, text=gen_text)
+ return ImageTextPipelineOutput(images=image, text=text)
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index f2b191496aaa..4e50bbefe933 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -32,6 +32,7 @@
from ...models.transformer_2d import Transformer2DModel
from ...models.unet_2d_condition import UNet2DConditionOutput
from ...utils import is_torch_version, logging
+from ...utils.torch_utils import apply_freeu
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -749,6 +750,7 @@ def __init__(
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
+ resolution_idx=i,
resnet_groups=norm_num_groups,
cross_attention_dim=reversed_cross_attention_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
@@ -941,6 +943,38 @@ def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
+ def enable_freeu(self, s1, s2, b1, b2):
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
+ def disable_freeu(self):
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None:
+ setattr(upsample_block, k, None)
+
def forward(
self,
sample: torch.FloatTensor,
@@ -965,6 +999,26 @@ def forward(
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`):
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
+ A tensor that if specified is added to the residual of the middle unet block.
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
@@ -1630,6 +1684,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -1670,12 +1725,33 @@ def __init__(
self.upsamplers = None
self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
@@ -1712,6 +1788,7 @@ def __init__(
out_channels: int,
prev_output_channel: int,
temb_channels: int,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
@@ -1790,6 +1867,7 @@ def __init__(
self.upsamplers = None
self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
def forward(
self,
@@ -1803,11 +1881,30 @@ def forward(
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
index fd1eb3f7b5e7..a248c25a5592 100644
--- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
+++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
@@ -539,7 +539,8 @@ def __call__(
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
index 413d9b67dbcd..4f9c0bd9f4e7 100644
--- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
@@ -380,7 +380,8 @@ def __call__(
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
index 9c9b854b8334..24ced7620350 100644
--- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
+++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
@@ -454,7 +454,8 @@ def __call__(
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
index e4d976504c6d..6caa09a46ce0 100644
--- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
@@ -213,13 +213,13 @@ def __call__(
Image Embeddings either extracted from an image or generated by a Prior Model.
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
- num_inference_steps (`int`, *optional*, defaults to 30):
+ num_inference_steps (`int`, *optional*, defaults to 12):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
- guidance_scale (`float`, *optional*, defaults to 4.0):
+ guidance_scale (`float`, *optional*, defaults to 0.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
@@ -352,7 +352,8 @@ def __call__(
).prev_sample
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# 10. Scale and decode the image latents with vq-vae
latents = self.vqgan.config.scale_factor * latents
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
index 6b5ce9530d4c..888d3c0dd74b 100644
--- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
@@ -194,7 +194,7 @@ def __call__(
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
`prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked
to the text `prompt`, usually at the expense of lower image quality.
- prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 30):
+ prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 60):
The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. For more specific timestep spacing, you can pass customized
`prior_timesteps`
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
index 8e737a74bbfe..dba6d7bb06db 100644
--- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
@@ -82,6 +82,12 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
scheduler ([`DDPMWuerstchenScheduler`]):
A scheduler to be used in combination with `prior` to generate image embedding.
+ latent_mean ('float', *optional*, defaults to 42.0):
+ Mean value for latent diffusers.
+ latent_std ('float', *optional*, defaults to 1.0):
+ Standard value for latent diffusers.
+ resolution_multiple ('float', *optional*, defaults to 42.67):
+ Default resolution for multiple images generated.
"""
model_cpu_offload_seq = "text_encoder->prior"
@@ -282,17 +288,17 @@ def __call__(
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
- height (`int`, *optional*, defaults to 512):
+ height (`int`, *optional*, defaults to 1024):
The height in pixels of the generated image.
- width (`int`, *optional*, defaults to 512):
+ width (`int`, *optional*, defaults to 1024):
The width in pixels of the generated image.
- num_inference_steps (`int`, *optional*, defaults to 30):
+ num_inference_steps (`int`, *optional*, defaults to 60):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
- guidance_scale (`float`, *optional*, defaults to 4.0):
+ guidance_scale (`float`, *optional*, defaults to 8.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
@@ -436,7 +442,8 @@ def __call__(
).prev_sample
if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
# 10. Denormalize the latents
latents = latents * self.config.latent_mean - self.config.latent_std
diff --git a/tests/pipelines/kandinsky_v22/__init__.py b/src/diffusers/py.typed
similarity index 100%
rename from tests/pipelines/kandinsky_v22/__init__.py
rename to src/diffusers/py.typed
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index aab5255abced..5881874ab57a 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -276,13 +276,13 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
- batch_size, channels, height, width = sample.shape
+ batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
- sample = sample.reshape(batch_size, channels * height * width)
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
@@ -290,11 +290,10 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
-
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
- sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py
index f90a271dfc06..8d698f67328e 100644
--- a/src/diffusers/schedulers/scheduling_ddim_parallel.py
+++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py
@@ -298,13 +298,13 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
- batch_size, channels, height, width = sample.shape
+ batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
- sample = sample.reshape(batch_size, channels * height * width)
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
@@ -312,11 +312,10 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
-
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
- sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 86f7e84ff07f..bbc390a5d9ca 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -330,13 +330,13 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
- batch_size, channels, height, width = sample.shape
+ batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
- sample = sample.reshape(batch_size, channels * height * width)
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
@@ -344,11 +344,10 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
-
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
- sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py
index 529d2bd03a75..ab7d70f466e6 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -198,7 +198,7 @@ def step(
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
- key: Optional[jax.random.KeyArray] = None,
+ key: Optional[jax.Array] = None,
return_dict: bool = True,
) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
"""
@@ -211,7 +211,7 @@ def step(
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
- key (`jax.random.KeyArray`): a PRNG key.
+ key (`jax.Array`): a PRNG key.
return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
Returns:
diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
index 2f3bdd39aaa4..ca17ca5499e7 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
@@ -344,13 +344,13 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
- batch_size, channels, height, width = sample.shape
+ batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
- sample = sample.reshape(batch_size, channels * height * width)
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
@@ -358,11 +358,10 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
-
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
- sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py
index af9b0381dcc4..a6afe744bd88 100644
--- a/src/diffusers/schedulers/scheduling_deis_multistep.py
+++ b/src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -268,13 +268,13 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
- batch_size, channels, height, width = sample.shape
+ batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
- sample = sample.reshape(batch_size, channels * height * width)
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
@@ -282,11 +282,10 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
-
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
- sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index 726ad138ad84..6b1a43630fa6 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -288,13 +288,13 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
- batch_size, channels, height, width = sample.shape
+ batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
- sample = sample.reshape(batch_size, channels * height * width)
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
@@ -302,11 +302,10 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
-
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
- sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
index c0b286a37060..fa8f362bd3b5 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
@@ -298,13 +298,13 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
- batch_size, channels, height, width = sample.shape
+ batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
- sample = sample.reshape(batch_size, channels * height * width)
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
@@ -312,11 +312,10 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
-
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
- sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index 6744a68b4c4b..bb7dc21e6fdb 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -302,13 +302,13 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
- batch_size, channels, height, width = sample.shape
+ batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
- sample = sample.reshape(batch_size, channels * height * width)
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
@@ -316,11 +316,10 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
-
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
- sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py
index 45c0dbddf7ef..4a8606007d5f 100644
--- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py
+++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py
@@ -17,6 +17,7 @@
from typing import Optional, Tuple, Union
import flax
+import jax
import jax.numpy as jnp
from jax import random
@@ -139,7 +140,7 @@ def add_noise_to_input(
state: KarrasVeSchedulerState,
sample: jnp.ndarray,
sigma: float,
- key: random.KeyArray,
+ key: jax.Array,
) -> Tuple[jnp.ndarray, float]:
"""
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py
index b6240559fc88..935f972a9bdb 100644
--- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py
+++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py
@@ -18,6 +18,7 @@
from typing import Optional, Tuple, Union
import flax
+import jax
import jax.numpy as jnp
from jax import random
@@ -169,7 +170,7 @@ def step_pred(
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
- key: random.KeyArray,
+ key: jax.Array,
return_dict: bool = True,
) -> Union[FlaxSdeVeOutput, Tuple]:
"""
@@ -228,7 +229,7 @@ def step_correct(
state: ScoreSdeVeSchedulerState,
model_output: jnp.ndarray,
sample: jnp.ndarray,
- key: random.KeyArray,
+ key: jax.Array,
return_dict: bool = True,
) -> Union[FlaxSdeVeOutput, Tuple]:
"""
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index 2b5bd4fd60db..741b03b6d3a2 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -282,13 +282,13 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
- batch_size, channels, height, width = sample.shape
+ batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
- sample = sample.reshape(batch_size, channels * height * width)
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
@@ -296,11 +296,10 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
-
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
- sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
@@ -534,14 +533,14 @@ def multistep_uni_p_bh_update(
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
- pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - alpha_t * B_h * pred_res
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
- pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - sigma_t * B_h * pred_res
@@ -670,7 +669,7 @@ def multistep_uni_c_bh_update(
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
- corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
@@ -678,7 +677,7 @@ def multistep_uni_c_bh_update(
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
- corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index b9e60a2a873b..1f51f2c0497b 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -25,7 +25,7 @@
ONNX_WEIGHTS_NAME = "model.onnx"
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
-HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
+HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index 99ea4d8cf1d0..7955ccb01d85 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -22,6 +22,7 @@
if is_torch_available():
import torch
+ from torch.fft import fftn, fftshift, ifftn, ifftshift
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -86,3 +87,61 @@ def is_compiled_module(module):
if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
return False
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
+
+
+def fourier_filter(x_in, threshold, scale):
+ """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
+
+ This version of the method comes from here:
+ https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706
+ """
+ x = x_in
+ B, C, H, W = x.shape
+
+ # Non-power of 2 images must be float32
+ if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
+ x = x.to(dtype=torch.float32)
+
+ # FFT
+ x_freq = fftn(x, dim=(-2, -1))
+ x_freq = fftshift(x_freq, dim=(-2, -1))
+
+ B, C, H, W = x_freq.shape
+ mask = torch.ones((B, C, H, W), device=x.device)
+
+ crow, ccol = H // 2, W // 2
+ mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
+ x_freq = x_freq * mask
+
+ # IFFT
+ x_freq = ifftshift(x_freq, dim=(-2, -1))
+ x_filtered = ifftn(x_freq, dim=(-2, -1)).real
+
+ return x_filtered.to(dtype=x_in.dtype)
+
+
+def apply_freeu(
+ resolution_idx: int, hidden_states: torch.Tensor, res_hidden_states: torch.Tensor, **freeu_kwargs
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Applies the FreeU mechanism as introduced in https:
+ //arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU.
+
+ Args:
+ resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied.
+ hidden_states (`torch.Tensor`): Inputs to the underlying block.
+ res_hidden_states (`torch.Tensor`): Features from the skip block corresponding to the underlying block.
+ s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features.
+ s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if resolution_idx == 0:
+ num_half_channels = hidden_states.shape[1] // 2
+ hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b1"]
+ res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s1"])
+ if resolution_idx == 1:
+ num_half_channels = hidden_states.shape[1] // 2
+ hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b2"]
+ res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"])
+
+ return hidden_states, res_hidden_states
diff --git a/tests/lora/test_lora_layers_old_backend.py b/tests/lora/test_lora_layers_old_backend.py
index ae90f8b6a4b8..8c1fb4877653 100644
--- a/tests/lora/test_lora_layers_old_backend.py
+++ b/tests/lora/test_lora_layers_old_backend.py
@@ -1028,6 +1028,47 @@ def test_load_lora_locally_safetensors(self):
sd_pipe.unload_lora_weights()
+ def test_lora_fuse_nan(self):
+ pipeline_components, lora_components = self.get_dummy_components()
+ sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+
+ # Emulate training.
+ set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
+ set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
+ set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=tmpdirname,
+ unet_lora_layers=lora_components["unet_lora_layers"],
+ text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+ text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+ safe_serialization=True,
+ )
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+ sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+ # corrupt one LoRA weight with `inf` values
+ with torch.no_grad():
+ sd_pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_layer.down.weight += float(
+ "inf"
+ )
+
+ # with `safe_fusing=True` we should see an Error
+ with self.assertRaises(ValueError):
+ sd_pipe.fuse_lora(safe_fusing=True)
+
+ # without we should not see an error, but every image will be black
+ sd_pipe.fuse_lora(safe_fusing=False)
+
+ out = sd_pipe("test", num_inference_steps=2, output_type="np").images
+
+ assert np.isnan(out).all()
+
def test_lora_fusion(self):
pipeline_components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
@@ -1554,7 +1595,7 @@ def test_lora_on_off(self, expected_max_diff=1e-3):
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
- def test_lora_xformers_on_off(self, expected_max_diff=1e-4):
+ def test_lora_xformers_on_off(self, expected_max_diff=6e-4):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py
index 32e3729828fa..64baeea910b8 100644
--- a/tests/pipelines/controlnet/test_controlnet.py
+++ b/tests/pipelines/controlnet/test_controlnet.py
@@ -85,16 +85,17 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
prompt = "bird"
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
- )
+ ).resize((512, 512))
- output = pipe(prompt, image, generator=generator, output_type="np")
+ output = pipe(prompt, image, num_inference_steps=10, generator=generator, output_type="np")
image = output.images[0]
- assert image.shape == (768, 512, 3)
+ assert image.shape == (512, 512, 3)
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy"
)
+ expected_image = np.resize(expected_image, (512, 512, 3))
assert np.abs(expected_image - image).max() < 1.0
@@ -118,7 +119,7 @@ class ControlNetPipelineFastTests(
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
+ block_out_channels=(4, 8),
layers_per_block=2,
sample_size=32,
in_channels=4,
@@ -126,15 +127,17 @@ def get_dummy_components(self):
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
+ norm_num_groups=1,
)
torch.manual_seed(0)
controlnet = ControlNetModel(
- block_out_channels=(32, 64),
+ block_out_channels=(4, 8),
layers_per_block=2,
in_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
cross_attention_dim=32,
conditioning_embedding_out_channels=(16, 32),
+ norm_num_groups=1,
)
torch.manual_seed(0)
scheduler = DDIMScheduler(
@@ -146,12 +149,13 @@ def get_dummy_components(self):
)
torch.manual_seed(0)
vae = AutoencoderKL(
- block_out_channels=[32, 64],
+ block_out_channels=[4, 8],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
+ norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
@@ -229,7 +233,7 @@ class StableDiffusionMultiControlNetPipelineFastTests(
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
+ block_out_channels=(4, 8),
layers_per_block=2,
sample_size=32,
in_channels=4,
@@ -237,6 +241,7 @@ def get_dummy_components(self):
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
+ norm_num_groups=1,
)
torch.manual_seed(0)
@@ -246,23 +251,25 @@ def init_weights(m):
m.bias.data.fill_(1.0)
controlnet1 = ControlNetModel(
- block_out_channels=(32, 64),
+ block_out_channels=(4, 8),
layers_per_block=2,
in_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
cross_attention_dim=32,
conditioning_embedding_out_channels=(16, 32),
+ norm_num_groups=1,
)
controlnet1.controlnet_down_blocks.apply(init_weights)
torch.manual_seed(0)
controlnet2 = ControlNetModel(
- block_out_channels=(32, 64),
+ block_out_channels=(4, 8),
layers_per_block=2,
in_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
cross_attention_dim=32,
conditioning_embedding_out_channels=(16, 32),
+ norm_num_groups=1,
)
controlnet2.controlnet_down_blocks.apply(init_weights)
@@ -276,12 +283,13 @@ def init_weights(m):
)
torch.manual_seed(0)
vae = AutoencoderKL(
- block_out_channels=[32, 64],
+ block_out_channels=[4, 8],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
+ norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
@@ -414,7 +422,7 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
+ block_out_channels=(4, 8),
layers_per_block=2,
sample_size=32,
in_channels=4,
@@ -422,6 +430,7 @@ def get_dummy_components(self):
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
+ norm_num_groups=1,
)
torch.manual_seed(0)
@@ -431,12 +440,13 @@ def init_weights(m):
m.bias.data.fill_(1.0)
controlnet = ControlNetModel(
- block_out_channels=(32, 64),
+ block_out_channels=(4, 8),
layers_per_block=2,
in_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
cross_attention_dim=32,
conditioning_embedding_out_channels=(16, 32),
+ norm_num_groups=1,
)
controlnet.controlnet_down_blocks.apply(init_weights)
@@ -450,12 +460,13 @@ def init_weights(m):
)
torch.manual_seed(0)
vae = AutoencoderKL(
- block_out_channels=[32, 64],
+ block_out_channels=[4, 8],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
+ norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py
index 1ec1f493b9f0..a9140f3d5a31 100644
--- a/tests/pipelines/controlnet/test_controlnet_inpaint.py
+++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py
@@ -39,6 +39,7 @@
enable_full_determinism,
floats_tensor,
load_numpy,
+ numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
@@ -550,7 +551,7 @@ def make_inpaint_condition(image, image_mask):
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/boy_ray_ban.npy"
)
- assert np.abs(expected_image - image).max() < 0.9e-1
+ assert numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten()) < 1e-2
def test_load_local(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
diff --git a/tests/pipelines/text_to_video/__init__.py b/tests/pipelines/kandinsky2_2/__init__.py
similarity index 100%
rename from tests/pipelines/text_to_video/__init__.py
rename to tests/pipelines/kandinsky2_2/__init__.py
diff --git a/tests/pipelines/kandinsky_v22/test_kandinsky.py b/tests/pipelines/kandinsky2_2/test_kandinsky.py
similarity index 100%
rename from tests/pipelines/kandinsky_v22/test_kandinsky.py
rename to tests/pipelines/kandinsky2_2/test_kandinsky.py
diff --git a/tests/pipelines/kandinsky_v22/test_kandinsky_combined.py b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
similarity index 100%
rename from tests/pipelines/kandinsky_v22/test_kandinsky_combined.py
rename to tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
diff --git a/tests/pipelines/kandinsky_v22/test_kandinsky_controlnet.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
similarity index 98%
rename from tests/pipelines/kandinsky_v22/test_kandinsky_controlnet.py
rename to tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
index cec209c7cfec..74a912faa33f 100644
--- a/tests/pipelines/kandinsky_v22/test_kandinsky_controlnet.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
@@ -221,6 +221,9 @@ def test_kandinsky_controlnet(self):
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
+ def test_inference_batch_single_identical(self):
+ super().test_inference_batch_single_identical(expected_max_diff=5e-4)
+
@nightly
@require_torch_gpu
diff --git a/tests/pipelines/kandinsky_v22/test_kandinsky_controlnet_img2img.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py
similarity index 100%
rename from tests/pipelines/kandinsky_v22/test_kandinsky_controlnet_img2img.py
rename to tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py
diff --git a/tests/pipelines/kandinsky_v22/test_kandinsky_img2img.py b/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
similarity index 100%
rename from tests/pipelines/kandinsky_v22/test_kandinsky_img2img.py
rename to tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
diff --git a/tests/pipelines/kandinsky_v22/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
similarity index 100%
rename from tests/pipelines/kandinsky_v22/test_kandinsky_inpaint.py
rename to tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
diff --git a/tests/pipelines/kandinsky_v22/test_kandinsky_prior.py b/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
similarity index 100%
rename from tests/pipelines/kandinsky_v22/test_kandinsky_prior.py
rename to tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
diff --git a/tests/pipelines/kandinsky_v22/test_kandinsky_prior_emb2emb.py b/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py
similarity index 100%
rename from tests/pipelines/kandinsky_v22/test_kandinsky_prior_emb2emb.py
rename to tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py
diff --git a/tests/pipelines/shap_e/test_shap_e_img2img.py b/tests/pipelines/shap_e/test_shap_e_img2img.py
index 55c0ae6bd02e..055dbe7a97d4 100644
--- a/tests/pipelines/shap_e/test_shap_e_img2img.py
+++ b/tests/pipelines/shap_e/test_shap_e_img2img.py
@@ -250,7 +250,7 @@ def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
def test_save_load_local(self):
- super().test_save_load_local(expected_max_difference=1e-3)
+ super().test_save_load_local(expected_max_difference=5e-3)
@unittest.skip("Key error is raised with accelerate")
def test_sequential_cpu_offload_forward_pass(self):
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index df9e8d47f1b1..d6a63b98912a 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -565,6 +565,47 @@ def test_attention_slicing_forward_pass(self):
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
+ def test_freeu_enabled(self):
+ components = self.get_dummy_components()
+ sd_pipe = StableDiffusionPipeline(**components)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "hey"
+ output = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
+
+ sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
+ output_freeu = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
+
+ assert not np.allclose(
+ output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
+ ), "Enabling of FreeU should lead to different results."
+
+ def test_freeu_disabled(self):
+ components = self.get_dummy_components()
+ sd_pipe = StableDiffusionPipeline(**components)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "hey"
+ output = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
+
+ sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
+ sd_pipe.disable_freeu()
+
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for upsample_block in sd_pipe.unet.up_blocks:
+ for key in freeu_keys:
+ assert getattr(upsample_block, key) is None, f"Disabling of FreeU should have set {key} to None."
+
+ output_no_freeu = sd_pipe(
+ prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)
+ ).images
+
+ assert np.allclose(
+ output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1]
+ ), "Disabling of FreeU should lead to results similar to the default pipeline results."
+
@slow
@require_torch_gpu
@@ -600,6 +641,20 @@ def test_stable_diffusion_1_1_pndm(self):
expected_slice = np.array([0.43625, 0.43554, 0.36670, 0.40660, 0.39703, 0.38658, 0.43936, 0.43557, 0.40592])
assert np.abs(image_slice - expected_slice).max() < 3e-3
+ def test_stable_diffusion_v1_4_with_freeu(self):
+ sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_inputs(torch_device)
+ inputs["num_inference_steps"] = 25
+
+ sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
+ image = sd_pipe(**inputs).images
+ image = image[0, -3:, -3:, -1].flatten()
+ expected_image = [0.0721, 0.0588, 0.0268, 0.0384, 0.0636, 0.0, 0.0429, 0.0344, 0.0309]
+ max_diff = np.abs(expected_image - image).max()
+ assert max_diff < 1e-3
+
def test_stable_diffusion_1_4_pndm(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
sd_pipe = sd_pipe.to(torch_device)
@@ -1079,7 +1134,7 @@ def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0
"generator": generator,
"num_inference_steps": 50,
"guidance_scale": 7.5,
- "output_type": "numpy",
+ "output_type": "np",
}
return inputs
@@ -1155,19 +1210,3 @@ def test_stable_diffusion_euler(self):
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
-
- def test_stable_diffusion_dpm(self):
- sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(torch_device)
- sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- inputs["num_inference_steps"] = 25
- image = sd_pipe(**inputs).images[0]
-
- expected_image = load_numpy(
- "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
- "/stable_diffusion_text2img/stable_diffusion_1_4_dpm_multi.npy"
- )
- max_diff = np.abs(expected_image - image).max()
- assert max_diff < 1e-3
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_gligen.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_gligen.py
index 19d44e0cd1d9..388ad9672e15 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_gligen.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_gligen.py
@@ -22,6 +22,7 @@
from diffusers import (
AutoencoderKL,
DDIMScheduler,
+ EulerAncestralDiscreteScheduler,
StableDiffusionGLIGENPipeline,
UNet2DConditionModel,
)
@@ -120,7 +121,7 @@ def get_dummy_inputs(self, device, seed=0):
}
return inputs
- def test_gligen(self):
+ def test_stable_diffusion_gligen_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionGLIGENPipeline(**components)
@@ -136,6 +137,24 @@ def test_gligen(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ def test_stable_diffusion_gligen_k_euler_ancestral(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ sd_pipe = StableDiffusionGLIGENPipeline(**components)
+ sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ output = sd_pipe(**inputs)
+ image = output.images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+ expected_slice = np.array([0.425, 0.494, 0.429, 0.469, 0.525, 0.417, 0.533, 0.5, 0.47])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_gligen_text_image.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_gligen_text_image.py
index 4e14adc81f42..f8f32643aec1 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_gligen_text_image.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_gligen_text_image.py
@@ -29,6 +29,7 @@
from diffusers import (
AutoencoderKL,
DDIMScheduler,
+ EulerAncestralDiscreteScheduler,
StableDiffusionGLIGENTextImagePipeline,
UNet2DConditionModel,
)
@@ -150,7 +151,7 @@ def get_dummy_inputs(self, device, seed=0):
}
return inputs
- def test_gligen(self):
+ def test_stable_diffusion_gligen_text_image_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionGLIGENTextImagePipeline(**components)
@@ -166,6 +167,24 @@ def test_gligen(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ def test_stable_diffusion_gligen_k_euler_ancestral(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ sd_pipe = StableDiffusionGLIGENTextImagePipeline(**components)
+ sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+
+ expected_slice = np.array([0.425, 0.494, 0.429, 0.469, 0.525, 0.417, 0.533, 0.5, 0.47])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py
index 2c0f37519ad8..aa5b3e38b0c1 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py
@@ -29,6 +29,7 @@
floats_tensor,
load_image,
load_numpy,
+ numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
@@ -479,3 +480,36 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 2.9 GB is allocated
assert mem_bytes < 2.9 * 10**9
+
+ def test_download_ckpt_diff_format_is_same(self):
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/sd2-upscale/low_res_cat.png"
+ )
+
+ prompt = "a cat sitting on a park bench"
+ model_id = "stabilityai/stable-diffusion-x4-upscaler"
+ pipe = StableDiffusionUpscalePipeline.from_pretrained(model_id)
+ pipe.enable_model_cpu_offload()
+
+ generator = torch.Generator("cpu").manual_seed(0)
+ output = pipe(prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3)
+ image_from_pretrained = output.images[0]
+
+ single_file_path = (
+ "https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
+ )
+ pipe_from_single_file = StableDiffusionUpscalePipeline.from_single_file(single_file_path)
+ pipe_from_single_file.enable_model_cpu_offload()
+
+ generator = torch.Generator("cpu").manual_seed(0)
+ output_from_single_file = pipe_from_single_file(
+ prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3
+ )
+ image_from_single_file = output_from_single_file.images[0]
+
+ assert image_from_pretrained.shape == (512, 512, 3)
+ assert image_from_single_file.shape == (512, 512, 3)
+ assert (
+ numpy_cosine_similarity_distance(image_from_pretrained.flatten(), image_from_single_file.flatten()) < 1e-3
+ )
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index 65c7526e3aa2..cebd860a4379 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -51,7 +51,7 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
+ block_out_channels=(2, 4),
layers_per_block=2,
sample_size=32,
in_channels=4,
@@ -66,6 +66,7 @@ def get_dummy_components(self):
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
cross_attention_dim=64,
+ norm_num_groups=1,
)
scheduler = EulerDiscreteScheduler(
beta_start=0.00085,
@@ -144,7 +145,7 @@ def test_stable_diffusion_xl_euler(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5873, 0.6128, 0.4797, 0.5122, 0.5674, 0.4639, 0.5227, 0.5149, 0.4747])
+ expected_slice = np.array([0.5552, 0.5569, 0.4725, 0.4348, 0.4994, 0.4632, 0.5142, 0.5012, 0.47])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/test_pipelines_flax.py b/tests/pipelines/test_pipelines_flax.py
index 294dad5ff0f1..fa2283d7a6b9 100644
--- a/tests/pipelines/test_pipelines_flax.py
+++ b/tests/pipelines/test_pipelines_flax.py
@@ -110,7 +110,7 @@ def test_stable_diffusion_v1_4(self):
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
- assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-3
+ assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-2
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 5e-1
def test_stable_diffusion_v1_4_bfloat_16(self):
@@ -139,7 +139,7 @@ def test_stable_diffusion_v1_4_bfloat_16(self):
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
- assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 1e-3
+ assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 5e-2
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
@@ -168,7 +168,7 @@ def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
- assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 1e-3
+ assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 5e-2
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
@@ -212,7 +212,7 @@ def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
- assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 1e-3
+ assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 5e-2
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1
def test_jax_memory_efficient_attention(self):
diff --git a/tests/pipelines/text_to_video_synthesis/__init__.py b/tests/pipelines/text_to_video_synthesis/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/text_to_video/test_text_to_video.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
similarity index 89%
rename from tests/pipelines/text_to_video/test_text_to_video.py
rename to tests/pipelines/text_to_video_synthesis/test_text_to_video.py
index 2c47dc492da1..933583ce4b70 100644
--- a/tests/pipelines/text_to_video/test_text_to_video.py
+++ b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
@@ -193,3 +193,21 @@ def test_two_step_model(self):
video = video_frames.cpu().numpy()
assert np.abs(expected_video - video).mean() < 5e-2
+
+ def test_two_step_model_with_freeu(self):
+ expected_video = []
+
+ pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
+ pipe = pipe.to(torch_device)
+
+ prompt = "Spiderman is surfing"
+ generator = torch.Generator(device="cpu").manual_seed(0)
+
+ pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
+ video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames
+ video = video_frames.cpu().numpy()
+ video = video[0, 0, -3:, -3:, -1].flatten()
+
+ expected_video = [-0.3102, -0.2477, -0.1772, -0.648, -0.6176, -0.5484, -0.0217, -0.056, -0.0177]
+
+ assert np.abs(expected_video - video).mean() < 5e-2
diff --git a/tests/pipelines/text_to_video/test_text_to_video_zero.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero.py
similarity index 100%
rename from tests/pipelines/text_to_video/test_text_to_video_zero.py
rename to tests/pipelines/text_to_video_synthesis/test_text_to_video_zero.py
diff --git a/tests/pipelines/text_to_video/test_video_to_video.py b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py
similarity index 100%
rename from tests/pipelines/text_to_video/test_video_to_video.py
rename to tests/pipelines/text_to_video_synthesis/test_video_to_video.py
diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py
index 01eee68c76be..ba8026db6154 100644
--- a/tests/pipelines/unidiffuser/test_unidiffuser.py
+++ b/tests/pipelines/unidiffuser/test_unidiffuser.py
@@ -1,5 +1,6 @@
import gc
import random
+import traceback
import unittest
import numpy as np
@@ -20,17 +21,70 @@
UniDiffuserPipeline,
UniDiffuserTextDecoder,
)
-from diffusers.utils.testing_utils import floats_tensor, load_image, nightly, require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ load_image,
+ nightly,
+ require_torch_2,
+ require_torch_gpu,
+ run_test_in_subprocess,
+ torch_device,
+)
from diffusers.utils.torch_utils import randn_tensor
-from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import (
+ IMAGE_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
+ TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
+)
+from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+
+
+enable_full_determinism()
+
+# Will be run via run_test_in_subprocess
+def _test_unidiffuser_compile(in_queue, out_queue, timeout):
+ error = None
+ try:
+ inputs = in_queue.get(timeout=timeout)
+ torch_device = inputs.pop("torch_device")
+ seed = inputs.pop("seed")
+ inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed)
+
+ pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
+ # pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+ pipe = pipe.to(torch_device)
+
+ pipe.unet.to(memory_format=torch.channels_last)
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+ pipe.set_progress_bar_config(disable=None)
-class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ image = pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1].flatten()
+
+ assert image.shape == (1, 512, 512, 3)
+ expected_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520])
+ assert np.abs(image_slice - expected_slice).max() < 1e-1
+ except Exception:
+ error = f"{traceback.format_exc()}"
+
+ results = {"error": error}
+ out_queue.put(results, timeout=timeout)
+ out_queue.join()
+
+
+class UniDiffuserPipelineFastTests(
+ PipelineTesterMixin, PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
+):
pipeline_class = UniDiffuserPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
+ image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+ # vae_latents, not latents, is the argument that corresponds to VAE latent inputs
+ image_latents_params = frozenset(["vae_latents"])
def get_dummy_components(self):
unet = UniDiffuserModel.from_pretrained(
@@ -64,7 +118,7 @@ def get_dummy_components(self):
subfolder="image_encoder",
)
# From the Stable Diffusion Image Variation pipeline tests
- image_processor = CLIPImageProcessor(crop_size=32, size=32)
+ clip_image_processor = CLIPImageProcessor(crop_size=32, size=32)
# image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip")
text_tokenizer = GPT2Tokenizer.from_pretrained(
@@ -80,7 +134,7 @@ def get_dummy_components(self):
"vae": vae,
"text_encoder": text_encoder,
"image_encoder": image_encoder,
- "image_processor": image_processor,
+ "clip_image_processor": clip_image_processor,
"clip_tokenizer": clip_tokenizer,
"text_decoder": text_decoder,
"text_tokenizer": text_tokenizer,
@@ -619,6 +673,19 @@ def test_unidiffuser_default_img2text_v1(self):
expected_text_prefix = "An astronaut"
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
+ @unittest.skip(reason="Skip torch.compile test to speed up the slow test suite.")
+ @require_torch_2
+ def test_unidiffuser_compile(self, seed=0):
+ inputs = self.get_inputs(torch_device, seed=seed, generate_latents=True)
+ # Delete prompt and image for joint inference.
+ del inputs["prompt"]
+ del inputs["image"]
+ # Can't pickle a Generator object
+ del inputs["generator"]
+ inputs["torch_device"] = torch_device
+ inputs["seed"] = seed
+ run_test_in_subprocess(test_case=self, target_func=_test_unidiffuser_compile, inputs=inputs)
+
@nightly
@require_torch_gpu
diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py
index 1b9a464ba6de..be41cea95b67 100644
--- a/tests/schedulers/test_scheduler_unipc.py
+++ b/tests/schedulers/test_scheduler_unipc.py
@@ -269,3 +269,113 @@ def test_full_loop_with_noise(self):
assert abs(result_sum.item() - 315.5757) < 1e-2, f" expected result sum 315.5757, but get {result_sum}"
assert abs(result_mean.item() - 0.4109) < 1e-3, f" expected result mean 0.4109, but get {result_mean}"
+
+
+class UniPCMultistepScheduler1DTest(UniPCMultistepSchedulerTest):
+ @property
+ def dummy_sample(self):
+ batch_size = 4
+ num_channels = 3
+ width = 8
+
+ sample = torch.rand((batch_size, num_channels, width))
+
+ return sample
+
+ @property
+ def dummy_noise_deter(self):
+ batch_size = 4
+ num_channels = 3
+ width = 8
+
+ num_elems = batch_size * num_channels * width
+ sample = torch.arange(num_elems).flip(-1)
+ sample = sample.reshape(num_channels, width, batch_size)
+ sample = sample / num_elems
+ sample = sample.permute(2, 0, 1)
+
+ return sample
+
+ @property
+ def dummy_sample_deter(self):
+ batch_size = 4
+ num_channels = 3
+ width = 8
+
+ num_elems = batch_size * num_channels * width
+ sample = torch.arange(num_elems)
+ sample = sample.reshape(num_channels, width, batch_size)
+ sample = sample / num_elems
+ sample = sample.permute(2, 0, 1)
+
+ return sample
+
+ def test_switch(self):
+ # make sure that iterating over schedulers with same config names gives same results
+ # for defaults
+ scheduler = UniPCMultistepScheduler(**self.get_scheduler_config())
+ sample = self.full_loop(scheduler=scheduler)
+ result_mean = torch.mean(torch.abs(sample))
+
+ assert abs(result_mean.item() - 0.2441) < 1e-3
+
+ scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+ scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+ scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+ scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+
+ sample = self.full_loop(scheduler=scheduler)
+ result_mean = torch.mean(torch.abs(sample))
+
+ assert abs(result_mean.item() - 0.2441) < 1e-3
+
+ def test_full_loop_no_noise(self):
+ sample = self.full_loop()
+ result_mean = torch.mean(torch.abs(sample))
+
+ assert abs(result_mean.item() - 0.2441) < 1e-3
+
+ def test_full_loop_with_karras(self):
+ sample = self.full_loop(use_karras_sigmas=True)
+ result_mean = torch.mean(torch.abs(sample))
+
+ assert abs(result_mean.item() - 0.2898) < 1e-3
+
+ def test_full_loop_with_v_prediction(self):
+ sample = self.full_loop(prediction_type="v_prediction")
+ result_mean = torch.mean(torch.abs(sample))
+
+ assert abs(result_mean.item() - 0.1014) < 1e-3
+
+ def test_full_loop_with_karras_and_v_prediction(self):
+ sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
+ result_mean = torch.mean(torch.abs(sample))
+
+ assert abs(result_mean.item() - 0.1944) < 1e-3
+
+ def test_full_loop_with_noise(self):
+ scheduler_class = self.scheduler_classes[0]
+ scheduler_config = self.get_scheduler_config()
+ scheduler = scheduler_class(**scheduler_config)
+
+ num_inference_steps = 10
+ t_start = 8
+
+ model = self.dummy_model()
+ sample = self.dummy_sample_deter
+ scheduler.set_timesteps(num_inference_steps)
+
+ # add noise
+ noise = self.dummy_noise_deter
+ timesteps = scheduler.timesteps[t_start * scheduler.order :]
+ sample = scheduler.add_noise(sample, noise, timesteps[:1])
+
+ for i, t in enumerate(timesteps):
+ residual = model(sample, t)
+ sample = scheduler.step(residual, t, sample).prev_sample
+
+ result_sum = torch.sum(torch.abs(sample))
+ result_mean = torch.mean(torch.abs(sample))
+
+ assert abs(result_sum.item() - 39.0870) < 1e-2, f" expected result sum 39.0870, but get {result_sum}"
+ assert abs(result_mean.item() - 0.4072) < 1e-3, f" expected result mean 0.4072, but get {result_mean}"
diff --git a/utils/fetch_torch_cuda_pipeline_test_matrix.py b/utils/fetch_torch_cuda_pipeline_test_matrix.py
new file mode 100644
index 000000000000..41a9c1c8270d
--- /dev/null
+++ b/utils/fetch_torch_cuda_pipeline_test_matrix.py
@@ -0,0 +1,96 @@
+import json
+import logging
+import os
+from collections import defaultdict
+from pathlib import Path
+
+from huggingface_hub import HfApi, ModelFilter
+
+import diffusers
+
+
+PATH_TO_REPO = Path(__file__).parent.parent.resolve()
+ALWAYS_TEST_PIPELINE_MODULES = [
+ "controlnet",
+ "stable_diffusion",
+ "stable_diffusion_2",
+ "stable_diffusion_xl",
+ "deepfloyd_if",
+ "kandinsky",
+ "kandinsky2_2",
+ "text_to_video_synthesis",
+ "wuerstchen",
+]
+PIPELINE_USAGE_CUTOFF = int(os.getenv("PIPELINE_USAGE_CUTOFF", 50000))
+
+logger = logging.getLogger(__name__)
+api = HfApi()
+filter = ModelFilter(library="diffusers")
+
+
+def filter_pipelines(usage_dict, usage_cutoff=10000):
+ output = []
+ for diffusers_object, usage in usage_dict.items():
+ if usage < usage_cutoff:
+ continue
+
+ if "Pipeline" in diffusers_object:
+ output.append(diffusers_object)
+
+ return output
+
+
+def fetch_pipeline_objects():
+ models = api.list_models(filter=filter)
+ downloads = defaultdict(int)
+
+ for model in models:
+ is_counted = False
+ for tag in model.tags:
+ if tag.startswith("diffusers:"):
+ is_counted = True
+ downloads[tag[len("diffusers:") :]] += model.downloads
+
+ if not is_counted:
+ downloads["other"] += model.downloads
+
+ # Remove 0 downloads
+ downloads = {k: v for k, v in downloads.items() if v > 0}
+ pipeline_objects = filter_pipelines(downloads, PIPELINE_USAGE_CUTOFF)
+
+ return pipeline_objects
+
+
+def fetch_pipeline_modules_to_test():
+ try:
+ pipeline_objects = fetch_pipeline_objects()
+ except Exception as e:
+ logger.error(e)
+ raise RuntimeError("Unable to fetch model list from HuggingFace Hub.")
+
+ test_modules = []
+ for pipeline_name in pipeline_objects:
+ module = getattr(diffusers, pipeline_name)
+ test_module = module.__module__.split(".")[-2].strip()
+ test_modules.append(test_module)
+
+ return test_modules
+
+
+def main():
+ test_modules = fetch_pipeline_modules_to_test()
+ test_modules.extend(ALWAYS_TEST_PIPELINE_MODULES)
+
+ # Get unique modules
+ test_modules = list(set(test_modules))
+ print(json.dumps(test_modules))
+
+ save_path = f"{PATH_TO_REPO}/reports"
+ os.makedirs(save_path, exist_ok=True)
+
+ with open(f"{save_path}/test-pipelines.json", "w") as f:
+ json.dump({"pipeline_test_modules": test_modules}, f)
+
+
+if __name__ == "__main__":
+ main()