From 1082fbdf111e15dfaa5631356e13628ab0fb1116 Mon Sep 17 00:00:00 2001 From: ssusie Date: Tue, 3 Oct 2023 00:56:12 +0000 Subject: [PATCH 1/3] Added mark_step for sdxl to run with pytorch xla. Also updated README with instructions for xla --- examples/text_to_image/README_sdxl.md | 45 +++++++++++++++---- .../pipeline_stable_diffusion_xl.py | 11 ++++- .../pipeline_stable_diffusion_xl_img2img.py | 9 ++++ .../pipeline_stable_diffusion_xl_inpaint.py | 9 ++++ ...ne_stable_diffusion_xl_instruct_pix2pix.py | 9 ++++ 5 files changed, 74 insertions(+), 9 deletions(-) diff --git a/examples/text_to_image/README_sdxl.md b/examples/text_to_image/README_sdxl.md index 4c2f92eaa8b8..75c9cb126472 100644 --- a/examples/text_to_image/README_sdxl.md +++ b/examples/text_to_image/README_sdxl.md @@ -44,7 +44,7 @@ from accelerate.utils import write_basic_config write_basic_config() ``` -When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. +When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. ### Training @@ -73,10 +73,10 @@ accelerate launch train_text_to_image_sdxl.py \ --push_to_hub ``` -**Notes**: +**Notes**: -* The `train_text_to_image_sdxl.py` script pre-computes text embeddings and the VAE encodings and keeps them in memory. While for smaller datasets like [`lambdalabs/pokemon-blip-captions`](https://hf.co/datasets/lambdalabs/pokemon-blip-captions), it might not be a problem, it can definitely lead to memory problems when the script is used on a larger dataset. For those purposes, you would want to serialize these pre-computed representations to disk separately and load them during the fine-tuning process. Refer to [this PR](https://github.com/huggingface/diffusers/pull/4505) for a more in-depth discussion. -* The training script is compute-intensive and may not run on a consumer GPU like Tesla T4. +* The `train_text_to_image_sdxl.py` script pre-computes text embeddings and the VAE encodings and keeps them in memory. While for smaller datasets like [`lambdalabs/pokemon-blip-captions`](https://hf.co/datasets/lambdalabs/pokemon-blip-captions), it might not be a problem, it can definitely lead to memory problems when the script is used on a larger dataset. For those purposes, you would want to serialize these pre-computed representations to disk separately and load them during the fine-tuning process. Refer to [this PR](https://github.com/huggingface/diffusers/pull/4505) for a more in-depth discussion. +* The training script is compute-intensive and may not run on a consumer GPU like Tesla T4. * The training command shown above performs intermediate quality validation in between the training epochs and logs the results to Weights and Biases. `--report_to`, `--validation_prompt`, and `--validation_epochs` are the relevant CLI arguments here. * SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)). @@ -95,6 +95,35 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] image.save("pokemon.png") ``` +### Inference in Pytorch XLA +```python +from diffusers import DiffusionPipeline +import torch +import torch_xla.core.xla_model as xm + +model_id = "stabilityai/stable-diffusion-xl-base-1.0" +pipe = DiffusionPipeline.from_pretrained(model_id) + +device = xm.xla_device() +pipe.to(device) + +prompt = "A pokemon with green eyes and red legs." +start = time() +image = pipe(prompt, num_inference_steps=inference_steps).images[0] +print(f'Compilation time is {time()-start} sec') +image.save("pokemon.png") + +start = time() +image = pipe(prompt, num_inference_steps=inference_steps).images[0] +print(f'Inference time is {time()-start} sec after compilation') +``` + +Note: There is a warmup step in PyTorch XLA. This takes longer because of +compilation and optimization. To see the real benefits of Pytorch XLA and +speedup, we need to call the pipe again on the input with the same length +as the original prompt to reuse the optimized graph and get the performance +boost. + ## LoRA training example for Stable Diffusion XL (SDXL) Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. @@ -112,7 +141,7 @@ on consumer GPUs like Tesla T4, Tesla V100. ### Training -First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables and, optionally, the `VAE_NAME` variable. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions). +First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables and, optionally, the `VAE_NAME` variable. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions). **___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see generating images during training. All you need to do is to run `pip install wandb` before training to automatically log images.___** @@ -122,7 +151,7 @@ export VAE_NAME="madebyollin/sdxl-vae-fp16-fix" export DATASET_NAME="lambdalabs/pokemon-blip-captions" ``` -For this example we want to directly store the trained LoRA embeddings on the Hub, so +For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag. ```bash @@ -149,7 +178,7 @@ accelerate launch train_text_to_image_lora_sdxl.py \ The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases. -**Notes**: +**Notes**: * SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)). @@ -178,7 +207,7 @@ accelerate launch train_text_to_image_lora_sdxl.py \ ### Inference -Once you have trained a model using above command, the inference can be done simply using the `DiffusionPipeline` after loading the trained LoRA weights. You +Once you have trained a model using above command, the inference can be done simply using the `DiffusionPipeline` after loading the trained LoRA weights. You need to pass the `output_dir` for loading the LoRA weights which, in this case, is `sd-pokemon-model-lora-sdxl`. ```python diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 98eeb8e3448c..b12d5c5f1166 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -48,6 +48,12 @@ if is_invisible_watermark_available(): from .watermark import StableDiffusionXLWatermarker +try: + import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True +except: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -832,7 +838,7 @@ def __call__( # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - # 7.1 Apply denoising_end + # 8.1 Apply denoising_end if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: discrete_timestep_cutoff = int( round( @@ -880,6 +886,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 9612a8e28f8e..ef8f2dc2670d 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -45,6 +45,12 @@ if is_invisible_watermark_available(): from .watermark import StableDiffusionXLWatermarker +try: + import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True +except: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -1031,6 +1037,9 @@ def denoising_value_valid(dnv): step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 209c9b339aec..623b6f0b061b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -47,6 +47,12 @@ if is_invisible_watermark_available(): from .watermark import StableDiffusionXLWatermarker +try: + import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True +except: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -1355,6 +1361,9 @@ def denoising_value_valid(dnv): step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 6fd1be88b284..14cc0ce39c88 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -44,6 +44,12 @@ if is_invisible_watermark_available(): from .watermark import StableDiffusionXLWatermarker +try: + import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True +except: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -926,6 +932,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast From a6941480b45c8b7ed55d28ea0663a23605481d98 Mon Sep 17 00:00:00 2001 From: ssusie Date: Wed, 4 Oct 2023 21:42:21 +0000 Subject: [PATCH 2/3] adding soft dependency on torch_xla --- .../pipeline_stable_diffusion_xl.py | 5 +++-- .../pipeline_stable_diffusion_xl_img2img.py | 5 +++-- .../pipeline_stable_diffusion_xl_inpaint.py | 5 +++-- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 5 +++-- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 12 ++++++++++++ 6 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index b12d5c5f1166..9efaaf417e0b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -35,6 +35,7 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_invisible_watermark_available, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -48,10 +49,10 @@ if is_invisible_watermark_available(): from .watermark import StableDiffusionXLWatermarker -try: +if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True -except: +else: XLA_AVAILABLE = False diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index ef8f2dc2670d..67582cf3cdd8 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -32,6 +32,7 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_invisible_watermark_available, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -45,10 +46,10 @@ if is_invisible_watermark_available(): from .watermark import StableDiffusionXLWatermarker -try: +if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True -except: +else: XLA_AVAILABLE = False diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 623b6f0b061b..618c1e0248e8 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -34,6 +34,7 @@ from ...utils import ( deprecate, is_invisible_watermark_available, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -47,10 +48,10 @@ if is_invisible_watermark_available(): from .watermark import StableDiffusionXLWatermarker -try: +if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True -except: +else: XLA_AVAILABLE = False diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 14cc0ce39c88..ed203a6b6ffc 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -33,6 +33,7 @@ from ...utils import ( deprecate, is_invisible_watermark_available, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -44,10 +45,10 @@ if is_invisible_watermark_available(): from .watermark import StableDiffusionXLWatermarker -try: +if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True -except: +else: XLA_AVAILABLE = False diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index b0e6a5169c7e..76d2177b0e56 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -72,6 +72,7 @@ is_scipy_available, is_tensorboard_available, is_torch_available, + is_torch_xla_available, is_torch_version, is_torchsde_available, is_transformers_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index b3fc086363e3..b3278af2f6a5 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -64,6 +64,14 @@ logger.info("Disabling PyTorch because USE_TORCH is set") _torch_available = False +_torch_xla_available = importlib.util.find_spec("torch_xla") is not None +if _torch_xla_available: + try: + _torch_xla_version = importlib_metadata.version("torch_xla") + logger.info(f"PyTorch XLA version {_torch_xla_version} available.") + except ImportError: + _torch_xla_available = False + _jax_version = "N/A" _flax_version = "N/A" if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: @@ -281,6 +289,10 @@ def is_torch_available(): return _torch_available +def is_torch_xla_available(): + return _torch_xla_available + + def is_flax_available(): return _flax_available From 4a897eb2c492a65b5e4526643de79255041fbac8 Mon Sep 17 00:00:00 2001 From: ssusie Date: Thu, 5 Oct 2023 18:01:31 +0000 Subject: [PATCH 3/3] fix some styling --- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 1 + .../stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py | 1 + .../stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py | 1 + .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 1 + src/diffusers/utils/__init__.py | 2 +- 5 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 9efaaf417e0b..127ecc4b5e30 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -51,6 +51,7 @@ if is_torch_xla_available(): import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True else: XLA_AVAILABLE = False diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 67582cf3cdd8..a5fb134f9913 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -48,6 +48,7 @@ if is_torch_xla_available(): import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True else: XLA_AVAILABLE = False diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 618c1e0248e8..a6e0531eae3a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -50,6 +50,7 @@ if is_torch_xla_available(): import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True else: XLA_AVAILABLE = False diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index ed203a6b6ffc..797b6c8af0f1 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -47,6 +47,7 @@ if is_torch_xla_available(): import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True else: XLA_AVAILABLE = False diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 76d2177b0e56..128ebb1fb737 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -72,8 +72,8 @@ is_scipy_available, is_tensorboard_available, is_torch_available, - is_torch_xla_available, is_torch_version, + is_torch_xla_available, is_torchsde_available, is_transformers_available, is_transformers_version,