diff --git a/examples/text_to_image/README_sdxl.md b/examples/text_to_image/README_sdxl.md
index 4c2f92eaa8b8..75c9cb126472 100644
--- a/examples/text_to_image/README_sdxl.md
+++ b/examples/text_to_image/README_sdxl.md
@@ -44,7 +44,7 @@ from accelerate.utils import write_basic_config
 write_basic_config()
 ```
 
-When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. 
+When running `accelerate config`, setting torch compile mode to True can bring dramatic speedups.
 
 ### Training
 
@@ -73,10 +73,10 @@ accelerate launch train_text_to_image_sdxl.py \
   --push_to_hub
 ```
 
-**Notes**: 
+**Notes**:
 
-* The `train_text_to_image_sdxl.py` script pre-computes text embeddings and the VAE encodings and keeps them in memory. While for smaller datasets like [`lambdalabs/pokemon-blip-captions`](https://hf.co/datasets/lambdalabs/pokemon-blip-captions), it might not be a problem, it can definitely lead to memory problems when the script is used on a larger dataset. For those purposes, you would want to serialize these pre-computed representations to disk separately and load them during the fine-tuning process. Refer to [this PR](https://github.com/huggingface/diffusers/pull/4505) for a more in-depth discussion. 
-* The training script is compute-intensive and may not run on a consumer GPU like Tesla T4. 
+* The `train_text_to_image_sdxl.py` script pre-computes text embeddings and the VAE encodings and keeps them in memory. While for smaller datasets like [`lambdalabs/pokemon-blip-captions`](https://hf.co/datasets/lambdalabs/pokemon-blip-captions), it might not be a problem, it can definitely lead to memory problems when the script is used on a larger dataset. For those purposes, you would want to serialize these pre-computed representations to disk separately and load them during the fine-tuning process. Refer to [this PR](https://github.com/huggingface/diffusers/pull/4505) for a more in-depth discussion.
+* The training script is compute-intensive and may not run on a consumer GPU like Tesla T4.
 * The training command shown above performs intermediate quality validation in between the training epochs and logs the results to Weights and Biases. `--report_to`, `--validation_prompt`, and `--validation_epochs` are the relevant CLI arguments here.
 * SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
 
@@ -95,6 +95,58 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
 image.save("pokemon.png")
 ```
 
+### Inference in PyTorch XLA
+
+```python
+from time import time
+
+import torch
+import torch_xla.core.xla_model as xm
+
+from diffusers import DiffusionPipeline
+
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+pipe = DiffusionPipeline.from_pretrained(model_id)
+
+device = xm.xla_device()
+pipe.to(device)
+
+prompt = "A pokemon with green eyes and red legs."
+inference_steps = 30
+
+# The first call is slow because XLA compiles and optimizes the graph.
+start = time()
+image = pipe(prompt, num_inference_steps=inference_steps).images[0]
+print(f'Compilation time is {time()-start} sec')
+image.save("pokemon.png")
+
+# Subsequent calls on inputs of the same shape reuse the compiled graph.
+start = time()
+image = pipe(prompt, num_inference_steps=inference_steps).images[0]
+print(f'Inference time is {time()-start} sec after compilation')
+```
+
+Note: The first pipeline call in PyTorch XLA includes a warmup step that
+takes longer because of graph compilation and optimization. To see the real
+benefits of PyTorch XLA and the speedup, call the pipe again on an input of
+the same length as the original prompt so that the optimized graph is reused
+and you get the performance boost.
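+
+For instance, after the warmup call above, several prompts can be served from
+the same compiled graph (a minimal sketch reusing `pipe`, `time`, and
+`inference_steps` from the snippet above; the prompt list is illustrative):
+
+```python
+prompts = [
+    "A pokemon with blue wings.",
+    "A pokemon with a flaming tail.",
+]
+for i, p in enumerate(prompts):
+    start = time()
+    image = pipe(p, num_inference_steps=inference_steps).images[0]
+    image.save(f"pokemon_{i}.png")
+    print(f'Prompt {i} took {time()-start} sec')
+```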
+
 ## LoRA training example for Stable Diffusion XL (SDXL)
 
 Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
@@ -112,9 +164,9 @@ on consumer GPUs like Tesla T4, Tesla V100.
 
 ### Training
 
-First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables and, optionally, the `VAE_NAME` variable. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions). 
+First, you need to set up your development environment as explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables and, optionally, the `VAE_NAME` variable. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
 
 **___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see generating images during training. All you need to do is to run `pip install wandb` before training to automatically log images.___**
 
 ```bash
 export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
@@ -122,7 +174,7 @@ export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
 export DATASET_NAME="lambdalabs/pokemon-blip-captions"
 ```
 
-For this example we want to directly store the trained LoRA embeddings on the Hub, so 
+For this example we want to directly store the trained LoRA embeddings on the Hub, so
 we need to be logged in and add the `--push_to_hub` flag.
 
 ```bash
@@ -149,7 +201,7 @@ accelerate launch train_text_to_image_lora_sdxl.py \
 
 The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases.
 
-**Notes**: 
+**Notes**:
 
 * SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
 
@@ -178,7 +230,7 @@ accelerate launch train_text_to_image_lora_sdxl.py \
 
 ### Inference
 
-Once you have trained a model using above command, the inference can be done simply using the `DiffusionPipeline` after loading the trained LoRA weights. You 
+Once you have trained a model using the above command, inference can be done simply using the `DiffusionPipeline` after loading the trained LoRA weights. You
 need to pass the `output_dir` for loading the LoRA weights which, in this case, is `sd-pokemon-model-lora-sdxl`.
 
 ```python
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 4c1bd857d7cb..61856d16a197 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -35,6 +35,7 @@
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     is_invisible_watermark_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -48,6 +49,13 @@
 if is_invisible_watermark_available():
     from .watermark import StableDiffusionXLWatermarker
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -860,7 +868,7 @@ def __call__(
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
-        # 7.1 Apply denoising_end
+        # 8.1 Apply denoising_end
         if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
             discrete_timestep_cutoff = int(
                 round(
@@ -908,6 +916,9 @@ def __call__(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
             needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 9612a8e28f8e..a5fb134f9913 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -32,6 +32,7 @@
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     is_invisible_watermark_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -45,6 +46,13 @@
 if is_invisible_watermark_available():
     from .watermark import StableDiffusionXLWatermarker
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -1031,6 +1039,9 @@ def denoising_value_valid(dnv):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
             needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 209c9b339aec..a6e0531eae3a 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -34,6 +34,7 @@
 from ...utils import (
     deprecate,
     is_invisible_watermark_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -47,6 +48,13 @@
 if is_invisible_watermark_available():
     from .watermark import StableDiffusionXLWatermarker
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -1355,6 +1363,9 @@ def denoising_value_valid(dnv):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
             needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
index 6fd1be88b284..797b6c8af0f1 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
@@ -33,6 +33,7 @@
 from ...utils import (
     deprecate,
     is_invisible_watermark_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -44,6 +45,13 @@
 if is_invisible_watermark_available():
     from .watermark import StableDiffusionXLWatermarker
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -926,6 +934,9 @@ def __call__(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
             needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index b0e6a5169c7e..128ebb1fb737 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -73,6 +73,7 @@
     is_tensorboard_available,
     is_torch_available,
     is_torch_version,
+    is_torch_xla_available,
     is_torchsde_available,
     is_transformers_available,
     is_transformers_version,
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index b3fc086363e3..b3278af2f6a5 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -64,6 +64,14 @@
     logger.info("Disabling PyTorch because USE_TORCH is set")
     _torch_available = False
 
+_torch_xla_available = importlib.util.find_spec("torch_xla") is not None
+if _torch_xla_available:
+    try:
+        _torch_xla_version = importlib_metadata.version("torch_xla")
+        logger.info(f"PyTorch XLA version {_torch_xla_version} available.")
+    except ImportError:
+        _torch_xla_available = False
+
 _jax_version = "N/A"
 _flax_version = "N/A"
 if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
@@ -281,6 +289,10 @@ def is_torch_available():
     return _torch_available
 
 
+def is_torch_xla_available():
+    return _torch_xla_available
+
+
 def is_flax_available():
     return _flax_available
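Taken together, the pipeline changes all follow one pattern: guard the `torch_xla` import behind `is_torch_xla_available()` at module import time, and call `xm.mark_step()` once per denoising iteration so that XLA's lazily traced graph is cut and dispatched at every step instead of growing across the whole loop. A minimal sketch of that pattern outside diffusers (here `step_fn` is a hypothetical stand-in for one UNet-plus-scheduler update, not a real diffusers API):

```python
import torch

try:
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False


def denoise(latents: torch.Tensor, timesteps, step_fn):
    # step_fn(latents, t) stands in for one UNet forward + scheduler.step().
    for t in timesteps:
        latents = step_fn(latents, t)
        if XLA_AVAILABLE:
            # End the lazy trace here so XLA compiles and executes one
            # graph per denoising step rather than one for the whole loop.
            xm.mark_step()
    return latents
```

Because `XLA_AVAILABLE` is fixed at import time, the per-step cost of this hook on non-XLA backends is a single boolean check.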