Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions examples/text_to_image/README_sdxl.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ from accelerate.utils import write_basic_config
write_basic_config()
```

When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.

### Training

Expand Down Expand Up @@ -73,10 +73,10 @@ accelerate launch train_text_to_image_sdxl.py \
--push_to_hub
```

**Notes**:
**Notes**:

* The `train_text_to_image_sdxl.py` script pre-computes text embeddings and the VAE encodings and keeps them in memory. While for smaller datasets like [`lambdalabs/pokemon-blip-captions`](https://hf.co/datasets/lambdalabs/pokemon-blip-captions), it might not be a problem, it can definitely lead to memory problems when the script is used on a larger dataset. For those purposes, you would want to serialize these pre-computed representations to disk separately and load them during the fine-tuning process. Refer to [this PR](https://github.com/huggingface/diffusers/pull/4505) for a more in-depth discussion.
* The training script is compute-intensive and may not run on a consumer GPU like Tesla T4.
* The `train_text_to_image_sdxl.py` script pre-computes text embeddings and the VAE encodings and keeps them in memory. While for smaller datasets like [`lambdalabs/pokemon-blip-captions`](https://hf.co/datasets/lambdalabs/pokemon-blip-captions), it might not be a problem, it can definitely lead to memory problems when the script is used on a larger dataset. For those purposes, you would want to serialize these pre-computed representations to disk separately and load them during the fine-tuning process. Refer to [this PR](https://github.com/huggingface/diffusers/pull/4505) for a more in-depth discussion.
* The training script is compute-intensive and may not run on a consumer GPU like Tesla T4.
* The training command shown above performs intermediate quality validation in between the training epochs and logs the results to Weights and Biases. `--report_to`, `--validation_prompt`, and `--validation_epochs` are the relevant CLI arguments here.
* SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).

Expand All @@ -95,6 +95,35 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("pokemon.png")
```

### Inference in Pytorch XLA
```python
from diffusers import DiffusionPipeline
import torch
import torch_xla.core.xla_model as xm

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = DiffusionPipeline.from_pretrained(model_id)

device = xm.xla_device()
pipe.to(device)

prompt = "A pokemon with green eyes and red legs."
start = time()
image = pipe(prompt, num_inference_steps=inference_steps).images[0]
print(f'Compilation time is {time()-start} sec')
image.save("pokemon.png")

start = time()
image = pipe(prompt, num_inference_steps=inference_steps).images[0]
print(f'Inference time is {time()-start} sec after compilation')
```

Note: There is a warmup step in PyTorch XLA. This takes longer because of
compilation and optimization. To see the real benefits of Pytorch XLA and
speedup, we need to call the pipe again on the input with the same length
as the original prompt to reuse the optimized graph and get the performance
Comment on lines +123 to +124
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the prompt length really matters here, as embeddings always have the same shape?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great question. With the current setup we need the same length for the input for XLA, otherwise it will recompile the whole graph. We can potentially cut the graph with xm.mark_step after embeddings are calculated, but it is not address in this PR.

boost.

## LoRA training example for Stable Diffusion XL (SDXL)

Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
Expand All @@ -112,7 +141,7 @@ on consumer GPUs like Tesla T4, Tesla V100.

### Training

First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables and, optionally, the `VAE_NAME` variable. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables and, optionally, the `VAE_NAME` variable. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).

**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see generating images during training. All you need to do is to run `pip install wandb` before training to automatically log images.___**

Expand All @@ -122,7 +151,7 @@ export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
```

For this example we want to directly store the trained LoRA embeddings on the Hub, so
For this example we want to directly store the trained LoRA embeddings on the Hub, so
we need to be logged in and add the `--push_to_hub` flag.

```bash
Expand All @@ -149,7 +178,7 @@ accelerate launch train_text_to_image_lora_sdxl.py \

The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases.

**Notes**:
**Notes**:

* SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).

Expand Down Expand Up @@ -178,7 +207,7 @@ accelerate launch train_text_to_image_lora_sdxl.py \

### Inference

Once you have trained a model using above command, the inference can be done simply using the `DiffusionPipeline` after loading the trained LoRA weights. You
Once you have trained a model using above command, the inference can be done simply using the `DiffusionPipeline` after loading the trained LoRA weights. You
need to pass the `output_dir` for loading the LoRA weights which, in this case, is `sd-pokemon-model-lora-sdxl`.

```python
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
is_invisible_watermark_available,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
Expand All @@ -48,6 +49,13 @@
if is_invisible_watermark_available():
from .watermark import StableDiffusionXLWatermarker

if is_torch_xla_available():
import torch_xla.core.xla_model as xm

XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False


logger = logging.get_logger(__name__) # pylint: disable=invalid-name

Expand Down Expand Up @@ -860,7 +868,7 @@ def __call__(
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

# 7.1 Apply denoising_end
# 8.1 Apply denoising_end
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
discrete_timestep_cutoff = int(
round(
Expand Down Expand Up @@ -908,6 +916,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)

if XLA_AVAILABLE:
xm.mark_step()

if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
is_invisible_watermark_available,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
Expand All @@ -45,6 +46,13 @@
if is_invisible_watermark_available():
from .watermark import StableDiffusionXLWatermarker

if is_torch_xla_available():
import torch_xla.core.xla_model as xm

XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False


logger = logging.get_logger(__name__) # pylint: disable=invalid-name

Expand Down Expand Up @@ -1031,6 +1039,9 @@ def denoising_value_valid(dnv):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)

if XLA_AVAILABLE:
xm.mark_step()

if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ...utils import (
deprecate,
is_invisible_watermark_available,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
Expand All @@ -47,6 +48,13 @@
if is_invisible_watermark_available():
from .watermark import StableDiffusionXLWatermarker

if is_torch_xla_available():
import torch_xla.core.xla_model as xm

XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False


logger = logging.get_logger(__name__) # pylint: disable=invalid-name

Expand Down Expand Up @@ -1355,6 +1363,9 @@ def denoising_value_valid(dnv):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)

if XLA_AVAILABLE:
xm.mark_step()

if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ...utils import (
deprecate,
is_invisible_watermark_available,
is_torch_xla_available,
logging,
replace_example_docstring,
)
Expand All @@ -44,6 +45,13 @@
if is_invisible_watermark_available():
from .watermark import StableDiffusionXLWatermarker

if is_torch_xla_available():
import torch_xla.core.xla_model as xm

XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False


logger = logging.get_logger(__name__) # pylint: disable=invalid-name

Expand Down Expand Up @@ -926,6 +934,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)

if XLA_AVAILABLE:
xm.mark_step()

if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
is_tensorboard_available,
is_torch_available,
is_torch_version,
is_torch_xla_available,
is_torchsde_available,
is_transformers_available,
is_transformers_version,
Expand Down
12 changes: 12 additions & 0 deletions src/diffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@
logger.info("Disabling PyTorch because USE_TORCH is set")
_torch_available = False

_torch_xla_available = importlib.util.find_spec("torch_xla") is not None
if _torch_xla_available:
try:
_torch_xla_version = importlib_metadata.version("torch_xla")
logger.info(f"PyTorch XLA version {_torch_xla_version} available.")
except ImportError:
_torch_xla_available = False

_jax_version = "N/A"
_flax_version = "N/A"
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
Expand Down Expand Up @@ -281,6 +289,10 @@ def is_torch_available():
return _torch_available


def is_torch_xla_available():
return _torch_xla_available


def is_flax_available():
return _flax_available

Expand Down