From 63b38841c039e035dcf40e627680ea14b9555507 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Thu, 9 Feb 2023 11:34:12 +0100
Subject: [PATCH 01/25] add sdpa processor

---
 src/diffusers/models/cross_attention.py | 48 ++++++++++++++++++++++++-
 1 file changed, 47 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index 2ea2e7be58e8..b8edc366882f 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -99,7 +99,11 @@ def __init__(
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
-        processor = processor if processor is not None else CrossAttnProcessor()
+        if hasattr(F, "scaled_dot_product_attention"):
+            defulti_processor = TorchAttentionProcessor()
+        else:
+            defulti_processor = CrossAttnProcessor()
+        processor = processor if processor is not None else defulti_processor
         self.set_processor(processor)
 
     def set_use_memory_efficient_attention_xformers(
@@ -436,6 +440,48 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
         return hidden_states
 
 
+class TorchAttentionProcessor:
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.cross_attention_norm:
+            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        batch_size, sequence_length, inner_dim = query.shape
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        hidden_states = F.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            attn_mask=None,
+            dropout_p=0.0,
+            is_causal=False,
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
 class LoRAXFormersCrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim, rank=4):
         super().__init__()

From ae47101cd7be52c0bade2211a6e7a3191dfdf499 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Thu, 9 Feb 2023 12:22:15 +0100
Subject: [PATCH 02/25] don't use it by default

---
 src/diffusers/models/cross_attention.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index b8edc366882f..314d124634da 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -99,11 +99,7 @@ def __init__(
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
-        if hasattr(F, "scaled_dot_product_attention"):
-            defulti_processor = TorchAttentionProcessor()
-        else:
-            defulti_processor = CrossAttnProcessor()
-        processor = processor if processor is not None else defulti_processor
+        processor = processor if processor is not None else CrossAttnProcessor
         self.set_processor(processor)
 
     def set_use_memory_efficient_attention_xformers(
@@ -441,6 +437,11 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
 
 
 class TorchAttentionProcessor:
+    def __init__(self):
+        # throw an error if scaled dot product attention is not available
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("TorchAttentionProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, _ = hidden_states.shape

From 92065dc8c8def9c655d28b2d1f51d50aac4ee71c Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Thu, 9 Feb 2023 12:27:14 +0100
Subject: [PATCH 03/25] add some checks and style

---
 .../text_to_image/train_text_to_image_lora.py |  4 ++--
 src/diffusers/models/cross_attention.py       | 20 ++++++++-----------
 2 files changed, 10 insertions(+), 14 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index a3c5bef73a95..abc535594d8c 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -418,9 +418,9 @@ def main():
     # freeze parameters of models to save more memory
     unet.requires_grad_(False)
     vae.requires_grad_(False)
-    
+
     text_encoder.requires_grad_(False)
-    
+
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index 314d124634da..56831b5a27e6 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -438,14 +438,16 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
 
 class TorchAttentionProcessor:
     def __init__(self):
-        # throw an error if scaled dot product attention is not available
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("TorchAttentionProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            raise ImportError(
+                "TorchAttentionProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size, _, inner_dim = hidden_states.shape
 
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        if attention_mask is not None:
+            raise NotImplementedError("Attention mask is not supported yet.")
 
         query = attn.to_q(hidden_states)
@@ -457,7 +459,6 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
-        batch_size, sequence_length, inner_dim = query.shape
         head_dim = inner_dim // attn.heads
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
@@ -465,12 +466,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         hidden_states = F.scaled_dot_product_attention(
-            query,
-            key,
-            value,
-            attn_mask=None,
-            dropout_p=0.0,
-            is_causal=False,
+            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
         )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

From 9656f01e2b46e758934a6073b5faf004692a5100 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Thu, 9 Feb 2023 12:34:27 +0100
Subject: [PATCH 04/25] typo

---
 src/diffusers/models/cross_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index 56831b5a27e6..e02f65920cdb 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -99,7 +99,7 @@ def __init__(
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
-        processor = processor if processor is not None else CrossAttnProcessor
+        processor = processor if processor is not None else CrossAttnProcessor()
         self.set_processor(processor)
 
     def set_use_memory_efficient_attention_xformers(

From 00849d0d90edda5546a332b7e3e0d9be2f73e063 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 14 Feb 2023 14:49:02 +0100
Subject: [PATCH 05/25] support torch sdpa in dreambooth example

---
 examples/dreambooth/train_dreambooth.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index ce8f0a52b8a1..0939da362013 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -41,6 +41,7 @@
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.models.cross_attention import TorchAttentionProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version
 from diffusers.utils.import_utils import is_xformers_available
@@ -321,6 +322,11 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--enable_torch_sdpa",
+        action="store_true",
+        help="Whether or not to use PyTorch's efficient SDPA implementation.",
+    )
     parser.add_argument(
         "--set_grads_to_none",
         action="store_true",
@@ -340,6 +346,12 @@ def parse_args(input_args=None):
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
 
+    # assert only one of enable_xformers_memory_efficient_attention and enable_torch_sdpa is set
+    if args.enable_xformers_memory_efficient_attention and args.enable_torch_sdpa:
+        raise ValueError(
+            "You can only use one of --enable_xformers_memory_efficient_attention and --enable_torch_sdpa at the same time."
+        )
+
     if args.with_prior_preservation:
         if args.class_data_dir is None:
             raise ValueError("You must specify a data directory for class images.")
@@ -650,6 +662,8 @@ def load_model_hook(models, input_dir):
             unet.enable_xformers_memory_efficient_attention()
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
+    elif args.enable_torch_sdpa:
+        unet.set_attn_processor(TorchAttentionProcessor())
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()

From e25cecddb9b7255c6e0348e7d8f1d60b1053dda5 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Wed, 15 Feb 2023 14:43:38 +0100
Subject: [PATCH 06/25] use torch attn proc by default when available

---
 examples/dreambooth/train_dreambooth.py | 14 --------------
 src/diffusers/models/cross_attention.py |  4 +++-
 2 files changed, 3 insertions(+), 15 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 0939da362013..ce8f0a52b8a1 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -41,7 +41,6 @@
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
-from diffusers.models.cross_attention import TorchAttentionProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version
 from diffusers.utils.import_utils import is_xformers_available
@@ -322,11 +321,6 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
-    parser.add_argument(
-        "--enable_torch_sdpa",
-        action="store_true",
-        help="Whether or not to use PyTorch's efficient SDPA implementation.",
-    )
     parser.add_argument(
         "--set_grads_to_none",
         action="store_true",
@@ -346,12 +340,6 @@ def parse_args(input_args=None):
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
 
-    # assert only one of enable_xformers_memory_efficient_attention and enable_torch_sdpa is set
-    if args.enable_xformers_memory_efficient_attention and args.enable_torch_sdpa:
-        raise ValueError(
-            "You can only use one of --enable_xformers_memory_efficient_attention and --enable_torch_sdpa at the same time."
-        )
-
     if args.with_prior_preservation:
         if args.class_data_dir is None:
             raise ValueError("You must specify a data directory for class images.")
@@ -662,8 +650,6 @@ def load_model_hook(models, input_dir):
             unet.enable_xformers_memory_efficient_attention()
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
-    elif args.enable_torch_sdpa:
-        unet.set_attn_processor(TorchAttentionProcessor())
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index 1032b57175c5..9019d51ef5ec 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -99,7 +99,9 @@ def __init__(
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
-        processor = processor if processor is not None else CrossAttnProcessor()
+        # We use the TorchAttentionProcessor by default when torch2.x is used which uses
+        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+        processor = TorchAttentionProcessor if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor
         self.set_processor(processor)
 
     def set_use_memory_efficient_attention_xformers(

From 23a40eb7927ae9546dd74e2ec34f962b55b11598 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Wed, 15 Feb 2023 14:49:58 +0100
Subject: [PATCH 07/25] typo

---
 src/diffusers/models/cross_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index 9019d51ef5ec..8790e7e1a15d 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -101,7 +101,7 @@ def __init__(
         # set attention processor
         # We use the TorchAttentionProcessor by default when torch2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
-        processor = TorchAttentionProcessor if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor
+        processor = TorchAttentionProcessor() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
         self.set_processor(processor)
 
     def set_use_memory_efficient_attention_xformers(

From 640f875a17f9b65604a199b989b6cc3a4e5dd3a3 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 17 Feb 2023 10:02:57 +0100
Subject: [PATCH 08/25] add attn mask

---
 src/diffusers/models/cross_attention.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index 8790e7e1a15d..05c000b5bf87 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -473,10 +473,11 @@ def __init__(self):
             )
 
     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, _, inner_dim = hidden_states.shape
+        batch_size, sequence_length, inner_dim = hidden_states.shape
 
         if attention_mask is not None:
-            raise NotImplementedError("Attention mask is not supported yet.")
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
         query = attn.to_q(hidden_states)
@@ -495,7 +496,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

From aaa83eca468c282aa839b9ad453db24bcde61526 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 17 Feb 2023 10:06:38 +0100
Subject: [PATCH 09/25] fix naming

---
 src/diffusers/models/cross_attention.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index 05c000b5bf87..9d488533717f 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -99,9 +99,9 @@ def __init__(
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
-        # We use the TorchAttentionProcessor by default when torch2.x is used which uses
+        # We use the Torch2AttnProcessor by default when torch2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
-        processor = TorchAttentionProcessor() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
+        processor = Torch2AttnProcessor() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
         self.set_processor(processor)
 
     def set_use_memory_efficient_attention_xformers(
@@ -465,7 +465,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
         return hidden_states
 
 
-class TorchAttentionProcessor:
+class Torch2AttnProcessor:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(

From e28f299a3e1b99ed455ba0e231f2c9804f5bc425 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 17 Feb 2023 10:15:11 +0100
Subject: [PATCH 10/25] being doc

---
 docs/source/en/optimization/torch2.0.mdx | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)
 create mode 100644 docs/source/en/optimization/torch2.0.mdx

diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx
new file mode 100644
index 000000000000..3579d7a2cbef
--- /dev/null
+++ b/docs/source/en/optimization/torch2.0.mdx
@@ -0,0 +1,18 @@
+
+
+# Torch2.0 support in diffusers.
+
+Starting from version `0.13.0` Diffusers supports latest optimization from the upcoming PyTorch 2.0 release. These include:
+1. Support for native flash and memory efficient attention (from xFormers) without any extra dependacnies
+2. `torch.compile` support for compiling the individual models for extra performance boost.
+

From 4d6458c6b053cc373eef8af4a60694caefef299b Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 17 Feb 2023 10:44:25 +0100
Subject: [PATCH 11/25] doc

---
 docs/source/en/optimization/torch2.0.mdx | 70 +++++++++++++++++++++++-
 1 file changed, 67 insertions(+), 3 deletions(-)

diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx
index 3579d7a2cbef..604574cf1641 100644
--- a/docs/source/en/optimization/torch2.0.mdx
+++ b/docs/source/en/optimization/torch2.0.mdx
@@ -12,7 +12,71 @@ specific language governing permissions and limitations under the License.
 
 # Torch2.0 support in diffusers.
 
-Starting from version `0.13.0` Diffusers supports latest optimization from the upcoming PyTorch 2.0 release. These include:
-1. Support for native flash and memory efficient attention (from xFormers) without any extra dependacnies
-2. `torch.compile` support for compiling the individual models for extra performance boost.
+Starting from version `0.13.0` Diffusers supports latest optimization from the upcoming [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/) release. These include:
+1. Support for native flash and memory efficient attention (from xFormers) without any extra dependacnies.
+2. [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) support for compiling individual models for extra performance boost.
+
+## Installation
+To benift from the the native effcient attention and `torch.compile`, we will need to install the nightly version of PyTorch as the stable for version is yet to be release. The first step is to install CUDA11.7 as torch2.0 does not support the previsou version. Once CUDA11.7 installed,
+torch nightly can be installed using
+
+```python
+pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu117
+```
+
+## Using efficient attention and torch.compile.
+
+
+1. **Efficient Attention**
+
+    The efficient attention is implemented via the `torch.nn.functional.scaled_dot_product_attention` function which automatically enables flash/memory efficient attention, depending on the input and the GPU type. This is same as the `memory_efficient_attention` from [xFormers](https://github.com/facebookresearch/xformers) but built natively into torch. For more information refer to torch [docs](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention).
+
+    The efficient attention will be enaled by default if torch2.0 is installed and if `torch.nn.functional.scaled_dot_product_attention` is available. To use it just install torch2.0 as suggested above and simply use the pipelines. For example:
+
+    ```Python
+    import torch
+    from diffusers import StableDiffusionPipeline
+
+    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+    pipe = pipe.to("cuda")
+
+    prompt = "a photo of an astronaut riding a horse on mars"
+    image = pipe(prompt).images[0]
+    ```
+
+    If you want to enable it explicitly (which is not required), you can do so as shown below.
+
+    ```Python
+    import torch
+    from diffusers import StableDiffusionPipeline
+    from diffusers.models.cross_attention import Torch2AttnProcessor
+
+    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+    pipe.unet.set_attn_processor(Torch2AttnProcessor())
+
+    prompt = "a photo of an astronaut riding a horse on mars"
+    image = pipe(prompt).images[0]
+    ```
+
+    This should be as fast and memory efficient as `xFormers`.
+
+
+2. **torch.compile**
+
+    To get further speed-up we can make use of `torch2.0`. To do so we simply wrap our `unet` with `torch.compile`. For more infomation and different options, refer to the
+    [torch compile docs](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html).
+
+    ```python
+    import torch
+    from diffusers import StableDiffusionPipeline
+
+    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None).to("cuda")
+    pipe.unet = torch.compile(pipe.unet)
+
+    batch_size = 10
+    prompt = "A photo of an astronaut riding a horse on marse."
+    images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
+    ```
+
+    Depending on the type of GPUs it can give between 2-9% spee-up over efficient attention. But note that as of now the speed-up is mostly noticable on high-end GPUs such as A100.
\ No newline at end of file

From 0b7eac8efc78127e40d32cc1adfef21dc3c05254 Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Fri, 17 Feb 2023 10:48:51 +0100
Subject: [PATCH 12/25] Apply suggestions from code review

---
 docs/source/en/optimization/torch2.0.mdx | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx
index 604574cf1641..1edb2c20a79c 100644
--- a/docs/source/en/optimization/torch2.0.mdx
+++ b/docs/source/en/optimization/torch2.0.mdx
@@ -12,16 +12,16 @@ specific language governing permissions and limitations under the License.
 
 # Torch2.0 support in diffusers.
 
-Starting from version `0.13.0` Diffusers supports latest optimization from the upcoming [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/) release. These include:
-1. Support for native flash and memory efficient attention (from xFormers) without any extra dependacnies.
+Starting from version `0.13.0`, Diffusers supports the latest optimization from the upcoming [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/) release. These include:
+1. Support for native flash and memory-efficient attention (from xFormers) without any extra dependencies.
 2. [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) support for compiling individual models for extra performance boost.
 
 ## Installation
-To benift from the the native effcient attention and `torch.compile`, we will need to install the nightly version of PyTorch as the stable for version is yet to be release. The first step is to install CUDA11.7 as torch2.0 does not support the previsou version. Once CUDA11.7 installed,
+To benefit from the native efficient attention and `torch.compile`, we will need to install the nightly version of PyTorch as the stable version is yet to be released. The first step is to install CUDA11.7, as torch2.0 does not support the previous version. Once CUDA11.7 is installed,
 torch nightly can be installed using
 
-```python
+```bash
 pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu117
 ```
 
@@ -30,9 +30,9 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
 
 1. **Efficient Attention**
 
-    The efficient attention is implemented via the `torch.nn.functional.scaled_dot_product_attention` function which automatically enables flash/memory efficient attention, depending on the input and the GPU type. This is same as the `memory_efficient_attention` from [xFormers](https://github.com/facebookresearch/xformers) but built natively into torch. For more information refer to torch [docs](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention).
+    The efficient attention is implemented via the `torch.nn.functional.scaled_dot_product_attention` function, which automatically enables flash/memory efficient attention, depending on the input and the GPU type. This is the same as the `memory_efficient_attention` from [xFormers](https://github.com/facebookresearch/xformers) but built natively into the torch. For more information, refer to torch [docs](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention).
 
-    The efficient attention will be enaled by default if torch2.0 is installed and if `torch.nn.functional.scaled_dot_product_attention` is available. To use it just install torch2.0 as suggested above and simply use the pipelines. For example:
+    The efficient attention will be enabled by default if torch2.0 is installed and if `torch.nn.functional.scaled_dot_product_attention` is available. To use it, you can just install torch2.0 as suggested above and use the pipeline. For example:
 
     ```Python
     import torch
     from diffusers import StableDiffusionPipeline
@@ -64,7 +64,7 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
 
 2. **torch.compile**
 
-    To get further speed-up we can make use of `torch2.0`. To do so we simply wrap our `unet` with `torch.compile`. For more infomation and different options, refer to the
+    To get further speed up, we can use the new `torch.compile` feature. To do so, we wrap our `unet` with `torch.compile`. For more information and different options, refer to the
     [torch compile docs](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html).
 
     ```python

From 907869b650cf098d26b5db94890416060b2a3911 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 17 Feb 2023 10:52:04 +0100
Subject: [PATCH 13/25] polish

---
 docs/source/en/optimization/torch2.0.mdx | 8 +++++---
 src/diffusers/models/cross_attention.py | 2 ++
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx
index 1edb2c20a79c..8f4bd1e76635 100644
--- a/docs/source/en/optimization/torch2.0.mdx
+++ b/docs/source/en/optimization/torch2.0.mdx
@@ -71,12 +71,14 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
     import torch
     from diffusers import StableDiffusionPipeline
 
-    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None).to("cuda")
+    pipe = StableDiffusionPipeline.from_pretrained(
+        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None
+    ).to("cuda")
     pipe.unet = torch.compile(pipe.unet)
 
     batch_size = 10
-    prompt = "A photo of an astronaut riding a horse on marse."
-    images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
+    prompt = "A photo of an astronaut riding a horse on marse."
+    images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
     ```
 
     Depending on the type of GPUs it can give between 2-9% spee-up over efficient attention. But note that as of now the speed-up is mostly noticable on high-end GPUs such as A100.
\ No newline at end of file
diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index 9d488533717f..f4a728aaf4d6 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -477,6 +477,8 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
 
         if attention_mask is not None:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
         query = attn.to_q(hidden_states)

From 8be0286e0a5f2fcd678217a956dee1131a8339cc Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 17 Feb 2023 11:01:01 +0100
Subject: [PATCH 14/25] torctree

---
 docs/source/en/_toctree.yml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 2f8cf19fea83..136d2ba759aa 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -55,6 +55,8 @@
 - sections:
   - local: optimization/fp16
     title: Memory and Speed
+  - local: optimization/torch2.0
+    title: Torch2.0 support.
   - local: optimization/xformers
     title: xFormers
   - local: optimization/onnx

From 967fed25c98cf0d5a6a1e0dadc449d65217aad35 Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Fri, 17 Feb 2023 11:18:56 +0100
Subject: [PATCH 15/25] Apply suggestions from code review

Co-authored-by: Sayak Paul
Co-authored-by: Patrick von Platen
---
 docs/source/en/optimization/torch2.0.mdx | 8 ++++----
 src/diffusers/models/cross_attention.py | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx
index 8f4bd1e76635..4a29d2ebd8ee 100644
--- a/docs/source/en/optimization/torch2.0.mdx
+++ b/docs/source/en/optimization/torch2.0.mdx
@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
 specific language governing permissions and limitations under the License.
 -->
 
-# Torch2.0 support in diffusers.
+# Torch2.0 support in Diffusers
 
 Starting from version `0.13.0`, Diffusers supports the latest optimization from the upcoming [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/) release. These include:
 1. Support for native flash and memory-efficient attention (from xFormers) without any extra dependencies.
@@ -30,9 +30,9 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
 
 1. **Efficient Attention**
 
-    The efficient attention is implemented via the `torch.nn.functional.scaled_dot_product_attention` function, which automatically enables flash/memory efficient attention, depending on the input and the GPU type. This is the same as the `memory_efficient_attention` from [xFormers](https://github.com/facebookresearch/xformers) but built natively into the torch. For more information, refer to torch [docs](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention).
+    The efficient attention is implemented via the [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) function, which automatically enables flash/memory efficient attention, depending on the input and the GPU type. This is the same as the `memory_efficient_attention` from [xFormers](https://github.com/facebookresearch/xformers) but built natively into the torch.
 
-    The efficient attention will be enabled by default if torch2.0 is installed and if `torch.nn.functional.scaled_dot_product_attention` is available. To use it, you can just install torch2.0 as suggested above and use the pipeline. For example:
+    The efficient attention will be enabled by default in Diffusers if torch2.0 is installed and if `torch.nn.functional.scaled_dot_product_attention` is available. To use it, you can install torch2.0 as suggested above and use the pipeline. For example:
 
     ```Python
     import torch
@@ -81,7 +81,7 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
     images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
     ```
 
-    Depending on the type of GPUs it can give between 2-9% spee-up over efficient attention. But note that as of now the speed-up is mostly noticable on high-end GPUs such as A100.
+    Depending on the type of GPU it can give between 2-9% speed-up over efficient attention. But note that as of now the speed-up is mostly noticeable on high-end GPUs such as A100.
\ No newline at end of file
diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index f4a728aaf4d6..9886f36db38d 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -469,7 +469,7 @@ class Torch2AttnProcessor:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
-                "TorchAttentionProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+                "Torch2AttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
) def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): From 34c6cdb8431cd90646db8b8321a714e3dbbcf27b Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 17 Feb 2023 11:19:35 +0100 Subject: [PATCH 16/25] better name --- docs/source/en/optimization/torch2.0.mdx | 4 ++-- src/diffusers/models/cross_attention.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx index 4a29d2ebd8ee..1c45f92587e1 100644 --- a/docs/source/en/optimization/torch2.0.mdx +++ b/docs/source/en/optimization/torch2.0.mdx @@ -50,10 +50,10 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl ```Python import torch from diffusers import StableDiffusionPipeline - from diffusers.models.cross_attention import Torch2AttnProcessor + from diffusers.models.cross_attention import AttnProccesor2_0 pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda") - pipe.unet.set_attn_processor(Torch2AttnProcessor()) + pipe.unet.set_attn_processor(AttnProccesor2_0()) prompt = "a photo of an astronaut riding a horse on mars" image = pipe(prompt).images[0] diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 9886f36db38d..612803acc571 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -99,9 +99,9 @@ def __init__( self.to_out.append(nn.Dropout(dropout)) # set attention processor - # We use the Torch2AttnProcessor by default when torch2.x is used which uses + # We use the AttnProccesor2_0 by default when torch2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention - processor = Torch2AttnProcessor() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor() + processor = AttnProccesor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor() self.set_processor(processor) def set_use_memory_efficient_attention_xformers( @@ -465,11 +465,11 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No return hidden_states -class Torch2AttnProcessor: +class AttnProccesor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( - "Torch2AttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + "AttnProccesor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." ) def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): From 4a38b12e163f7bdf77140490304c41c0a66616cb Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 17 Feb 2023 11:21:19 +0100 Subject: [PATCH 17/25] style --- src/diffusers/models/cross_attention.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 612803acc571..48bf25d13612 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -468,9 +468,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No class AttnProccesor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "AttnProccesor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
-            )
+            raise ImportError("AttnProccesor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, inner_dim = hidden_states.shape

From a21bb888100bcedfed363cdd67b9a5eda1cfd958 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 17 Feb 2023 11:33:46 +0100
Subject: [PATCH 18/25] add benchamrk table

---
 docs/source/en/optimization/torch2.0.mdx | 55 +++++++++++++++++++++++-
 1 file changed, 54 insertions(+), 1 deletion(-)

diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx
index 1c45f92587e1..f556997bdc86 100644
--- a/docs/source/en/optimization/torch2.0.mdx
+++ b/docs/source/en/optimization/torch2.0.mdx
@@ -81,4 +81,57 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
     images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
     ```
 
-    Depending on the type of GPU it can give between 2-9% speed-up over efficient attention. But note that as of now the speed-up is mostly noticeable on high-end GPUs such as A100.
\ No newline at end of file
+    Depending on the type of GPU it can give between 2-9% speed-up over efficient attention. But note that as of now the speed-up is mostly noticeable on high-end GPUs such as A100.
+
+
+## Benchmark
+
+We conducted a simple benchmark on different GPUs to compare between vannila attention, xFormers, `torch.nn.functional.scaled_dot_product_attention` and `torch.compile+torch.nn.functional.scaled_dot_product_attention`.
+For the benchmark we used the the [stable-diffusion-v1-4]() model with `torch.float16`, for 50 steps. `xFormers` benchmark is done using the `torch==1.13.1` version. The table below summarizes the result that we got.
+The `Speed over xformers` columns denotes the speed-up gained over `xFormers` using the `torch.compile+torch.nn.functional.scaled_dot_product_attention`.
+
+As we can see `torch.nn.functional.scaled_dot_product_attention` is fast as `xFormers` (somtimes slighly faster/slower) on all the GPUs we tested. And using `torch.compile` gives further speed-up, but it's motsly noticable on the A100 GPU.
+
+___The time reported is in seconds.___
+
+| GPU | Batch Size | Vanilla Attention | xFormers | PyTorch2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) |
+| --- | --- | --- | --- | --- | --- | --- |
+| A100 | 10 | 12.02 | 8.7 | 8.79 | 7.89 | 9.31 |
+| A100 | 16 | 18.95 | 13.57 | 13.67 | 12.25 | 9.73 |
+| A100 | 32 (1) | OOM | 26.56 | 26.68 | 24.08 | 9.34 |
+| A100 | 64(2) | | 52.51 | 53.03 | 47.81 | 8.95 |
+| | | | | | | |
+| T4 | 4 | 38.81 | 30.09 | 29.74 | 27.55 | 8.44 |
+| T4 | 8 | OOM | 55.71 | 55.99 | 53.85 | 3.34 |
+| T4 | 10 | OOM | 68.96 | 69.86 | 65.35 | 5.23 |
+| T4 | 16 | OOM | 111.47 | 113.26 | 106.93 | 4.07 |
+| | | | | | | |
+| V100 | 4 | 9.84 | 8.16 | 8.09 | 7.65 | 6.25 |
+| V100 | 8 | OOM | 15.62 | 15.44 | 14.59 | 6.59 |
+| V100 | 10 | OOM | 19.52 | 19.28 | 18.18 | 6.86 |
+| V100 | 16 | OOM | 30.29 | 29.84 | 28.22 | 6.83 |
+| | | | | | | |
+| A10 | 4 | 13.94 | 9.81 | 10.01 | 9.35 | 4.69 |
+| A10 | 8 | 27.09 | 19 | 19.53 | 18.33 | 3.53 |
+| A10 | 10 | 33.69 | 23.53 | 24.19 | 22.52 | 4.29 |
+| A10 | 16 | OOM | 37.55 | 38.31 | 36.81 | 1.97 |
+| A10 | 32 (1) | | 77.19 | 78.43 | 76.64 | 0.71 |
+| A10 | 64 (1) | | 173.59 | 158.99 | 155.14 | 10.63 |
+| | | | | | | |
+| 3090 | 4 | 10.04 | 7.82 | 7.89 | 7.47 | 4.48 |
+| 3090 | 8 | 19.27 | 14.97 | 15.04 | 14.22 | 5.01 |
+| 3090 | 10| 24.08 | 18.7 | 18.7 | 17.69 | 5.40 |
+| 3090 | 16 | OOM | 29.06 | 29.06 | 28.2 | 2.96 |
+| 3090 | 32 (1) | | 58.05 | 58 | 54.88 | 5.46 |
+| 3090 | 64 (1) | | 126.54 | 126.03 | 117.33 | 7.28 |
+| | | | | | | |
+| 3090 Ti | 4 | 9.07 | 7.14 | 7.15 | 6.81 | 4.62 |
+| 3090 Ti | 8 | 17.51 | 13.65 | 13.72 | 12.99 | 4.84 |
+| 3090 Ti | 10 (2) | 21.79 | 16.85 | 16.93 | 16.02 | 4.93 |
+| 3090 Ti | 16 | OOM | 26.1 | 26.28 | 25.46 | 2.45 |
+| 3090 Ti | 32 (1) | | 51.78 | 52.04 | 49.15 | 5.08 |
+| 3090 Ti | 64 (1) | | 112.02 | 112.33 | 103.91 | 7.24 |
+
+
+(1) Batch Size >= 32 requires enable_vae_slicing() because of https://github.com/pytorch/pytorch/issues/81665
+This is required for PyTorch 1.13.1, and also for PyTorch 2.0 and batch size of 64

From 59630d2e916ff6bde3365277cade84185729a89f Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Fri, 17 Feb 2023 11:34:40 +0100
Subject: [PATCH 19/25] Update docs/source/en/optimization/torch2.0.mdx

---
 docs/source/en/optimization/torch2.0.mdx | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx
index f556997bdc86..b68297ba359e 100644
--- a/docs/source/en/optimization/torch2.0.mdx
+++ b/docs/source/en/optimization/torch2.0.mdx
@@ -86,11 +86,11 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
 
 ## Benchmark
 
-We conducted a simple benchmark on different GPUs to compare between vannila attention, xFormers, `torch.nn.functional.scaled_dot_product_attention` and `torch.compile+torch.nn.functional.scaled_dot_product_attention`.
+We conducted a simple benchmark on different GPUs to compare vanilla attention, xFormers, `torch.nn.functional.scaled_dot_product_attention` and `torch.compile+torch.nn.functional.scaled_dot_product_attention`.
 For the benchmark we used the the [stable-diffusion-v1-4]() model with `torch.float16`, for 50 steps. `xFormers` benchmark is done using the `torch==1.13.1` version. The table below summarizes the result that we got.
 The `Speed over xformers` columns denotes the speed-up gained over `xFormers` using the `torch.compile+torch.nn.functional.scaled_dot_product_attention`.
-As we can see `torch.nn.functional.scaled_dot_product_attention` is fast as `xFormers` (somtimes slighly faster/slower) on all the GPUs we tested. And using `torch.compile` gives further speed-up, but it's motsly noticable on the A100 GPU.
+As we can see, `torch.nn.functional.scaled_dot_product_attention` is fast as `xFormers` (sometimes slightly faster/slower) on all the GPUs we tested. And using `torch.compile` gives further speed-up, but it's mostly noticeable on the A100 GPU.
 
 ___The time reported is in seconds.___

From b5c656570e40bf76b9978fb0f6658915f1875e1d Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 17 Feb 2023 11:35:55 +0100
Subject: [PATCH 20/25] up

---
 docs/source/en/optimization/torch2.0.mdx | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx
index f556997bdc86..97912611aa9f 100644
--- a/docs/source/en/optimization/torch2.0.mdx
+++ b/docs/source/en/optimization/torch2.0.mdx
@@ -135,3 +135,5 @@ ___The time reported is in seconds.___
 
 (1) Batch Size >= 32 requires enable_vae_slicing() because of https://github.com/pytorch/pytorch/issues/81665
 This is required for PyTorch 1.13.1, and also for PyTorch 2.0 and batch size of 64
+
+For more details about how this benchmark was run, please refer to [this PR](https://github.com/huggingface/diffusers/pull/2303).

From c5a7bd51d8a4e1a63383952fd1cfa743d9692ce6 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 17 Feb 2023 11:38:59 +0100
Subject: [PATCH 21/25] fix example

---
 docs/source/en/optimization/torch2.0.mdx | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx
index c3fb7efadcf7..7d260578e5ce 100644
--- a/docs/source/en/optimization/torch2.0.mdx
+++ b/docs/source/en/optimization/torch2.0.mdx
@@ -71,9 +71,9 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
     import torch
     from diffusers import StableDiffusionPipeline
 
-    pipe = StableDiffusionPipeline.from_pretrained(
-        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None
-    ).to("cuda")
+    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to(
+        "cuda"
+    )
     pipe.unet = torch.compile(pipe.unet)
 
     batch_size = 10

From 88efe89637da142998c97249d3aff3267f73867e Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 17 Feb 2023 11:50:18 +0100
Subject: [PATCH 22/25] check if processor is None

---
 src/diffusers/models/cross_attention.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index 48bf25d13612..27a515645de6 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -101,7 +101,8 @@ def __init__(
         # set attention processor
         # We use the AttnProccesor2_0 by default when torch2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
-        processor = AttnProccesor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
+        if processor is None:
+            processor = AttnProccesor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
         self.set_processor(processor)
 
     def set_use_memory_efficient_attention_xformers(

From 238ee95490b2a29e636694d0d29acaf1899513e2 Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Fri, 17 Feb 2023 11:56:09 +0100
Subject: [PATCH 23/25] Apply suggestions from code review

Co-authored-by: Pedro Cuenca
---
 docs/source/en/optimization/torch2.0.mdx | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx
index 7d260578e5ce..e69f2008d131 100644
--- a/docs/source/en/optimization/torch2.0.mdx
+++ b/docs/source/en/optimization/torch2.0.mdx
@@ -30,9 +30,9 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
 
 1. **Efficient Attention**
 
-    The efficient attention is implemented via the [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) function, which automatically enables flash/memory efficient attention, depending on the input and the GPU type. This is the same as the `memory_efficient_attention` from [xFormers](https://github.com/facebookresearch/xformers) but built natively into the torch.
+    Efficient attention is implemented via the [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) function, which automatically enables flash/memory efficient attention, depending on the input and the GPU type. This is the same as the `memory_efficient_attention` from [xFormers](https://github.com/facebookresearch/xformers) but built natively into PyTorch.
 
-    The efficient attention will be enabled by default in Diffusers if torch2.0 is installed and if `torch.nn.functional.scaled_dot_product_attention` is available. To use it, you can install torch2.0 as suggested above and use the pipeline. For example:
+    Efficient attention will be enabled by default in Diffusers if torch2.0 is installed and if `torch.nn.functional.scaled_dot_product_attention` is available. To use it, you can install torch2.0 as suggested above and use the pipeline. For example:
 
     ```Python
     import torch
@@ -64,7 +64,7 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
 
 2. **torch.compile**
 
-    To get further speed up, we can use the new `torch.compile` feature. To do so, we wrap our `unet` with `torch.compile`. For more information and different options, refer to the
+    To get an additional speedup, we can use the new `torch.compile` feature. To do so, we wrap our `unet` with `torch.compile`. For more information and different options, refer to the
     [torch compile docs](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html).
 
     ```python
@@ -81,7 +81,9 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
     images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
     ```
 
-    Depending on the type of GPU it can give between 2-9% speed-up over efficient attention. But note that as of now the speed-up is mostly noticeable on high-end GPUs such as A100.
+    Depending on the type of GPU it can give between 2-9% speed-up over efficient attention. But note that as of now the speed-up is mostly noticeable on the more recent GPU architectures, such as in the A100.
+
+    Note that compilation will also take some time to complete, so it is best suited for situations where you need to prepare your pipeline once and then perform the same type of inference operations multiple times.
 
 ## Benchmark

From 3607326b6676acf2db1da90f1df2876a2b9591e8 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 17 Feb 2023 12:11:24 +0100
Subject: [PATCH 24/25] add fp32 benchmakr

---
 docs/source/en/optimization/torch2.0.mdx | 71 ++++++++++++++++++++++--
 1 file changed, 65 insertions(+), 6 deletions(-)

diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx
index e69f2008d131..20af6fc15f73 100644
--- a/docs/source/en/optimization/torch2.0.mdx
+++ b/docs/source/en/optimization/torch2.0.mdx
@@ -13,13 +13,13 @@ specific language governing permissions and limitations under the License.
 # Torch2.0 support in Diffusers
 
 Starting from version `0.13.0`, Diffusers supports the latest optimization from the upcoming [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/) release. These include:
-1. Support for native flash and memory-efficient attention (from xFormers) without any extra dependencies.
+1. Support for native flash and memory-efficient attention without any extra dependencies.
 2. [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) support for compiling individual models for extra performance boost.
 
 ## Installation
-To benefit from the native efficient attention and `torch.compile`, we will need to install the nightly version of PyTorch as the stable version is yet to be released. The first step is to install CUDA11.7, as torch2.0 does not support the previous version. Once CUDA11.7 is installed,
-torch nightly can be installed using
+To benefit from the native efficient attention and `torch.compile`, we will need to install the nightly version of PyTorch as the stable version is yet to be released. The first step is to install CUDA11.7 or CUDA11.8,
+as torch2.0 does not support the previous versions. Once CUDA is installed, torch nightly can be installed using:
 
 ```bash
 pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu117
@@ -89,10 +89,14 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
 ## Benchmark
 
 We conducted a simple benchmark on different GPUs to compare vanilla attention, xFormers, `torch.nn.functional.scaled_dot_product_attention` and `torch.compile+torch.nn.functional.scaled_dot_product_attention`.
-For the benchmark we used the the [stable-diffusion-v1-4]() model with `torch.float16`, for 50 steps. `xFormers` benchmark is done using the `torch==1.13.1` version. The table below summarizes the result that we got.
+For the benchmark we used the the [stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) model with 50 steps. `xFormers` benchmark is done using the `torch==1.13.1` version. The table below summarizes the result that we got.
 The `Speed over xformers` columns denotes the speed-up gained over `xFormers` using the `torch.compile+torch.nn.functional.scaled_dot_product_attention`.
 
-As we can see, `torch.nn.functional.scaled_dot_product_attention` is fast as `xFormers` (sometimes slightly faster/slower) on all the GPUs we tested. And using `torch.compile` gives further speed-up, but it's mostly noticeable on the A100 GPU.
+
+### FP16 benchmark
+
+The table below shows the benchmark results for inference using `fp16`. As we can see, `torch.nn.functional.scaled_dot_product_attention` is as fast as `xFormers` (sometimes slightly faster/slower) on all the GPUs we tested.
+And using `torch.compile` gives further speed-up up to 10% over `xFormers`, but it's mostly noticeable on the A100 GPU.
 
 ___The time reported is in seconds.___
 
@@ -139,7 +143,62 @@ ___The time reported is in seconds.___
 | 3090 Ti | 64 (1) | | 112.02 | 112.33 | 103.91 | 7.24 |
 
 
+
+### FP32 benchmark
+
+The table below shows the benchmark results for inference using `fp32`. As we can see, `torch.nn.functional.scaled_dot_product_attention` is as fast as `xFormers` (sometimes slightly faster/slower) on all the GPUs we tested.
+Using `torch.compile` with efficient attention gives up to 18% performance improvement over `xFormers` in Ampere cards, and up to 20% over vanilla attention.
+
+| GPU | Batch Size | Vanilla Attention | xFormers | PyTorch2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) | Speed over vanilla (%) |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| A100 | 4 | 16.56 | 12.42 | 12.2 | 11.84 | 4.67 | 28.50 |
+| A100 | 10 | OOM | 29.93 | 29.44 | 28.5 | 4.78 | |
+| A100 | 16 | | 47.08 | 46.27 | 44.8 | 4.84 | |
+| A100 | 32 | | 92.89 | 91.34 | 88.35 | 4.89 | |
+| A100 | 64 | | 185.3 | 182.71 | 176.48 | 4.76 | |
+| | | | | | | |
+| T4 | 1 | 28.2 | 24.49 | 23.93 | 23.56 | 3.80 | 16.45 |
+| T4 | 2 | 52.77 | 45.7 | 45.88 | 45.06 | 1.40 | 14.61 |
+| T4 | 4 | OOM | 85.72 | 85.78 | 84.48 | 1.45 | |
+| T4 | 8 | | 149.64 | 150.75 | 148.4 | 0.83 | |
+| | | | | | | |
+| V100 | 1 | 7.4 | 6.84 | 6.8 | 6.66 | 2.63 | 10.00 |
+| V100 | 2 | 13.85 | 12.81 | 12.66 | 12.35 | 3.59 | 10.83 |
+| V100 | 4 | OOM | 25.73 | 25.31 | 24.78 | 3.69 | |
+| V100 | 8 | | 43.95 | 43.37 | 42.25 | 3.87 | |
+| V100 | 16 | | 84.99 | 84.73 | 82.55 | 2.87 | |
+| | | | | | | |
+| 3090 | 1 | 7.09 | 6.78 | 6.11 | 6.03 | 11.06 | 14.95 |
+| 3090 | 4 | 22.69 | 21.45 | 18.67 | 18.09 | 15.66 | 20.27 |
+| 3090 | 8 (2) | | 42.59 | 36.75 | 35.59 | 16.44 | |
+| 3090 | 16 | | 85.35 | 72.37 | 70.25 | 17.69 | |
+| 3090 | 32 (1) | | 162.05 | 138.99 | 134.53 | 16.98 | |
+| 3090 | 48 | | 241.91 | 207.75 | | 14.12 | |
+| | | | | | | |
+| 3090 Ti | 1 | 6.45 | 6.19 | 5.64 | 5.49 | 11.31 | 14.88 |
+| 3090 Ti | 4 | 20.32 | 19.31 | 16.9 | 16.37 | 15.23 | 19.44 |
+| 3090 Ti | 8 (2) | | 37.93 | 33.05 | 31.99 | 15.66 | |
+| 3090 Ti | 16 | | 75.37 | 65.25 | 64.32 | 14.66 | |
+| 3090 Ti | 32 (1) | | 142.55 | 124.44 | 120.74 | 15.30 | |
+| 3090 Ti | 48 | | 213.19 | 186.55 | | 12.50 | |
+| | | | | | | |
+| 4090 | 1 | 5.54 | 4.99 | | | | |
+| 4090 | 4 | 13.67 | 11.4 | | | | |
+| 4090 | 8 (2) | | 19.79 | | | | |
+| 4090 | 16 | | 38.62 | | | | |
+| 4090 | 32 (1) | | 76.57 | | | | |
+| 4090 | 48 | | 114.44 | | | 13.68 | |
+| | | | | | | |
+| A10 | 1 | 10.59 | 8.81 | 7.51 | 7.35 | 16.57 | 30.59 |
+| A10 | 4 | 34.77 | 27.63 | 22.77 | 22.07 | 20.12 | 36.53 |
+| A10 | 8 | | 56.19 | 43.53 | 43.86 | 21.94 | |
+| A10 | 16 | | 116.49 | 88.56 | 86.64 | 25.62 | |
+| A10 | 32 | | 221.95 | 175.74 | 168.18 | 24.23 | |
+| A10 | 48 | | 333.23 | 264.84 | | 20.52 | |
+| | | | | | | |
+
+
 (1) Batch Size >= 32 requires enable_vae_slicing() because of https://github.com/pytorch/pytorch/issues/81665
 This is required for PyTorch 1.13.1, and also for PyTorch 2.0 and batch size of 64
 
-For more details about how this benchmark was run, please refer to [this PR](https://github.com/huggingface/diffusers/pull/2303).
+For more details about how this benchmark was run, please refer to [this PR](https://github.com/huggingface/diffusers/pull/2303).
\ No newline at end of file

From e45029dcd67a30cc060a70e18bfe98ca62096021 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 17 Feb 2023 13:22:21 +0100
Subject: [PATCH 25/25] Apply suggestions from code review

Co-authored-by: Sayak Paul
---
 docs/source/en/_toctree.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 136d2ba759aa..a5e271380ec1 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -56,7 +56,7 @@
     - local: optimization/fp16
       title: Memory and Speed
     - local: optimization/torch2.0
-      title: Torch2.0 support.
+      title: Torch2.0 support
     - local: optimization/xformers
       title: xFormers
    - local: optimization/onnx
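
For reference, a minimal sketch of how the processor introduced in this series can be enabled explicitly and roughly timed. It assumes a CUDA GPU, a PyTorch 2.0 build, and a diffusers checkout containing this patch series (so `AttnProccesor2_0` is importable from `diffusers.models.cross_attention`, as named in patch 16); the prompt, step count, and batch size are arbitrary example values rather than the settings used for the benchmark tables above.

```python
import time

import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProccesor2_0

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Explicitly opt in to the torch 2.0 scaled_dot_product_attention processor.
# With this patch series applied, it is also selected automatically whenever
# F.scaled_dot_product_attention exists and no other processor is passed.
pipe.unet.set_attn_processor(AttnProccesor2_0())

prompt = "a photo of an astronaut riding a horse on mars"

# Warm-up pass so one-time CUDA initialization is not counted in the timing.
pipe(prompt, num_inference_steps=50)

torch.cuda.synchronize()
start = time.perf_counter()
images = pipe(prompt, num_inference_steps=50, num_images_per_prompt=4).images
torch.cuda.synchronize()
print(f"{time.perf_counter() - start:.2f} s for a batch of 4 images")
```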