diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 2f8cf19fea83..a5e271380ec1 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -55,6 +55,8 @@
 - sections:
   - local: optimization/fp16
     title: Memory and Speed
+  - local: optimization/torch2.0
+    title: Torch 2.0 support
   - local: optimization/xformers
     title: xFormers
   - local: optimization/onnx
diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx
new file mode 100644
index 000000000000..20af6fc15f73
--- /dev/null
+++ b/docs/source/en/optimization/torch2.0.mdx
@@ -0,0 +1,200 @@
+
+
+# Torch 2.0 support in Diffusers
+
+Starting from version `0.13.0`, Diffusers supports the latest optimizations from the upcoming [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/) release. These include:
+1. Support for native flash and memory-efficient attention without any extra dependencies.
+2. [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) support for compiling individual models for an extra performance boost.
+
+
+## Installation
+
+To benefit from the native efficient attention and `torch.compile`, we need to install the nightly version of PyTorch, as the stable version has not been released yet. The first step is to install CUDA 11.7 or CUDA 11.8, as PyTorch 2.0 does not support earlier CUDA versions. Once CUDA is installed, the PyTorch nightly build can be installed with:
+
+```bash
+pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu117
+```
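+
+To verify that the nightly build exposes the new attention function, you can run a quick sanity check such as the one below. This snippet is only an illustration and not part of the official installation steps; `scaled_dot_product_attention` is only present in PyTorch 2.0, so it should print `True`.
+
+```Python
+import torch
+import torch.nn.functional as F
+
+# Nightly builds report a 2.0 development version
+print(torch.__version__)
+
+# The native efficient attention entry point only exists in PyTorch 2.0
+print(hasattr(F, "scaled_dot_product_attention"))
+```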
+
+## Using efficient attention and torch.compile
+
+
+1. **Efficient Attention**
+
+    Efficient attention is implemented via the [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) function, which automatically enables flash or memory-efficient attention depending on the inputs and the GPU type. This is the same as the `memory_efficient_attention` from [xFormers](https://github.com/facebookresearch/xformers), but built natively into PyTorch.
+
+    Efficient attention is enabled by default in Diffusers if PyTorch 2.0 is installed and `torch.nn.functional.scaled_dot_product_attention` is available. To use it, install PyTorch 2.0 as described above and simply use the pipeline. For example:
+
+    ```Python
+    import torch
+    from diffusers import StableDiffusionPipeline
+
+    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+    pipe = pipe.to("cuda")
+
+    prompt = "a photo of an astronaut riding a horse on mars"
+    image = pipe(prompt).images[0]
+    ```
+
+    If you want to enable it explicitly (which is not required), you can do so as shown below.
+
+    ```Python
+    import torch
+    from diffusers import StableDiffusionPipeline
+    from diffusers.models.cross_attention import AttnProccesor2_0
+
+    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+    pipe.unet.set_attn_processor(AttnProccesor2_0())
+
+    prompt = "a photo of an astronaut riding a horse on mars"
+    image = pipe(prompt).images[0]
+    ```
+
+    This should be as fast and memory-efficient as `xFormers`.
+
+
+2. **torch.compile**
+
+    To get an additional speedup, we can use the new `torch.compile` feature. To do so, we simply wrap our `unet` with `torch.compile`. For more information and different options, refer to the
+    [torch compile docs](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html).
+
+    ```python
+    import torch
+    from diffusers import StableDiffusionPipeline
+
+    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to(
+        "cuda"
+    )
+    pipe.unet = torch.compile(pipe.unet)
+
+    batch_size = 10
+    prompt = "A photo of an astronaut riding a horse on mars."
+    images = pipe(prompt, num_inference_steps=50, num_images_per_prompt=batch_size).images
+    ```
+
+    Depending on the type of GPU, this can give a 2-9% speed-up over efficient attention alone. Note that, as of now, the speed-up is mostly noticeable on more recent GPU architectures such as the A100.
+
+    Compilation also takes some time to complete, so it is best suited for situations where you prepare your pipeline once and then perform the same type of inference operation many times.
+
+
+## Benchmark
+
+We conducted a simple benchmark on different GPUs to compare vanilla attention, xFormers, `torch.nn.functional.scaled_dot_product_attention`, and `torch.compile` combined with `torch.nn.functional.scaled_dot_product_attention`.
+For the benchmark we used the [stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) model with 50 inference steps. The `xFormers` benchmark was run with `torch==1.13.1`. The tables below summarize the results we got.
+The `Speed over xformers` column denotes the speed-up gained over `xFormers` when using `torch.compile` together with `torch.nn.functional.scaled_dot_product_attention`.
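+
+If you want to reproduce a rough version of these timings yourself, a minimal sketch of the timing loop is shown below. It is not the exact harness used for the numbers in the tables (that lives in the PR linked at the end of this document), and the prompt, batch size, and warm-up choices here are illustrative assumptions.
+
+```Python
+import time
+
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
+).to("cuda")
+pipe.unet = torch.compile(pipe.unet)  # optional: comment out to time the uncompiled pipeline
+
+prompt = "a photo of an astronaut riding a horse on mars"
+
+# Warm-up run so that compilation and kernel selection do not skew the measurement
+pipe(prompt, num_inference_steps=50, num_images_per_prompt=4)
+
+torch.cuda.synchronize()
+start = time.time()
+pipe(prompt, num_inference_steps=50, num_images_per_prompt=4)
+torch.cuda.synchronize()
+print(f"batch of 4, 50 steps: {time.time() - start:.2f} s")
+```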
+
+### FP16 benchmark
+
+The table below shows the benchmark results for inference using `fp16`. As we can see, `torch.nn.functional.scaled_dot_product_attention` is as fast as `xFormers` (sometimes slightly faster or slower) on all the GPUs we tested,
+and using `torch.compile` gives a further speed-up of up to 10% over `xFormers`, which is mostly noticeable on the A100 GPU.
+
+___The time reported is in seconds.___
+
+| GPU | Batch Size | Vanilla Attention | xFormers | PyTorch2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) |
+| --- | --- | --- | --- | --- | --- | --- |
+| A100 | 10 | 12.02 | 8.7 | 8.79 | 7.89 | 9.31 |
+| A100 | 16 | 18.95 | 13.57 | 13.67 | 12.25 | 9.73 |
+| A100 | 32 (1) | OOM | 26.56 | 26.68 | 24.08 | 9.34 |
+| A100 | 64 (2) | | 52.51 | 53.03 | 47.81 | 8.95 |
+| | | | | | | |
+| T4 | 4 | 38.81 | 30.09 | 29.74 | 27.55 | 8.44 |
+| T4 | 8 | OOM | 55.71 | 55.99 | 53.85 | 3.34 |
+| T4 | 10 | OOM | 68.96 | 69.86 | 65.35 | 5.23 |
+| T4 | 16 | OOM | 111.47 | 113.26 | 106.93 | 4.07 |
+| | | | | | | |
+| V100 | 4 | 9.84 | 8.16 | 8.09 | 7.65 | 6.25 |
+| V100 | 8 | OOM | 15.62 | 15.44 | 14.59 | 6.59 |
+| V100 | 10 | OOM | 19.52 | 19.28 | 18.18 | 6.86 |
+| V100 | 16 | OOM | 30.29 | 29.84 | 28.22 | 6.83 |
+| | | | | | | |
+| A10 | 4 | 13.94 | 9.81 | 10.01 | 9.35 | 4.69 |
+| A10 | 8 | 27.09 | 19 | 19.53 | 18.33 | 3.53 |
+| A10 | 10 | 33.69 | 23.53 | 24.19 | 22.52 | 4.29 |
+| A10 | 16 | OOM | 37.55 | 38.31 | 36.81 | 1.97 |
+| A10 | 32 (1) | | 77.19 | 78.43 | 76.64 | 0.71 |
+| A10 | 64 (1) | | 173.59 | 158.99 | 155.14 | 10.63 |
+| | | | | | | |
+| 3090 | 4 | 10.04 | 7.82 | 7.89 | 7.47 | 4.48 |
+| 3090 | 8 | 19.27 | 14.97 | 15.04 | 14.22 | 5.01 |
+| 3090 | 10 | 24.08 | 18.7 | 18.7 | 17.69 | 5.40 |
+| 3090 | 16 | OOM | 29.06 | 29.06 | 28.2 | 2.96 |
+| 3090 | 32 (1) | | 58.05 | 58 | 54.88 | 5.46 |
+| 3090 | 64 (1) | | 126.54 | 126.03 | 117.33 | 7.28 |
+| | | | | | | |
+| 3090 Ti | 4 | 9.07 | 7.14 | 7.15 | 6.81 | 4.62 |
+| 3090 Ti | 8 | 17.51 | 13.65 | 13.72 | 12.99 | 4.84 |
+| 3090 Ti | 10 (2) | 21.79 | 16.85 | 16.93 | 16.02 | 4.93 |
+| 3090 Ti | 16 | OOM | 26.1 | 26.28 | 25.46 | 2.45 |
+| 3090 Ti | 32 (1) | | 51.78 | 52.04 | 49.15 | 5.08 |
+| 3090 Ti | 64 (1) | | 112.02 | 112.33 | 103.91 | 7.24 |
+
+
+### FP32 benchmark
+
+The table below shows the benchmark results for inference using `fp32`. As we can see, `torch.nn.functional.scaled_dot_product_attention` is as fast as `xFormers` (sometimes slightly faster or slower) on all the GPUs we tested.
+Using `torch.compile` with efficient attention gives up to an 18% performance improvement over `xFormers` on Ampere cards, and up to 20% over vanilla attention.
+
+| GPU | Batch Size | Vanilla Attention | xFormers | PyTorch2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) | Speed over vanilla (%) |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| A100 | 4 | 16.56 | 12.42 | 12.2 | 11.84 | 4.67 | 28.50 |
+| A100 | 10 | OOM | 29.93 | 29.44 | 28.5 | 4.78 | |
+| A100 | 16 | | 47.08 | 46.27 | 44.8 | 4.84 | |
+| A100 | 32 | | 92.89 | 91.34 | 88.35 | 4.89 | |
+| A100 | 64 | | 185.3 | 182.71 | 176.48 | 4.76 | |
+| | | | | | | | |
+| T4 | 1 | 28.2 | 24.49 | 23.93 | 23.56 | 3.80 | 16.45 |
+| T4 | 2 | 52.77 | 45.7 | 45.88 | 45.06 | 1.40 | 14.61 |
+| T4 | 4 | OOM | 85.72 | 85.78 | 84.48 | 1.45 | |
+| T4 | 8 | | 149.64 | 150.75 | 148.4 | 0.83 | |
+| | | | | | | | |
+| V100 | 1 | 7.4 | 6.84 | 6.8 | 6.66 | 2.63 | 10.00 |
+| V100 | 2 | 13.85 | 12.81 | 12.66 | 12.35 | 3.59 | 10.83 |
+| V100 | 4 | OOM | 25.73 | 25.31 | 24.78 | 3.69 | |
+| V100 | 8 | | 43.95 | 43.37 | 42.25 | 3.87 | |
+| V100 | 16 | | 84.99 | 84.73 | 82.55 | 2.87 | |
+| | | | | | | | |
+| 3090 | 1 | 7.09 | 6.78 | 6.11 | 6.03 | 11.06 | 14.95 |
+| 3090 | 4 | 22.69 | 21.45 | 18.67 | 18.09 | 15.66 | 20.27 |
+| 3090 | 8 (2) | | 42.59 | 36.75 | 35.59 | 16.44 | |
+| 3090 | 16 | | 85.35 | 72.37 | 70.25 | 17.69 | |
+| 3090 | 32 (1) | | 162.05 | 138.99 | 134.53 | 16.98 | |
+| 3090 | 48 | | 241.91 | 207.75 | | 14.12 | |
+| | | | | | | | |
+| 3090 Ti | 1 | 6.45 | 6.19 | 5.64 | 5.49 | 11.31 | 14.88 |
+| 3090 Ti | 4 | 20.32 | 19.31 | 16.9 | 16.37 | 15.23 | 19.44 |
+| 3090 Ti | 8 (2) | | 37.93 | 33.05 | 31.99 | 15.66 | |
+| 3090 Ti | 16 | | 75.37 | 65.25 | 64.32 | 14.66 | |
+| 3090 Ti | 32 (1) | | 142.55 | 124.44 | 120.74 | 15.30 | |
+| 3090 Ti | 48 | | 213.19 | 186.55 | | 12.50 | |
+| | | | | | | | |
+| 4090 | 1 | 5.54 | 4.99 | | | | |
+| 4090 | 4 | 13.67 | 11.4 | | | | |
+| 4090 | 8 (2) | | 19.79 | | | | |
+| 4090 | 16 | | 38.62 | | | | |
+| 4090 | 32 (1) | | 76.57 | | | | |
+| 4090 | 48 | | 114.44 | | | 13.68 | |
+| | | | | | | | |
+| A10 | 1 | 10.59 | 8.81 | 7.51 | 7.35 | 16.57 | 30.59 |
+| A10 | 4 | 34.77 | 27.63 | 22.77 | 22.07 | 20.12 | 36.53 |
+| A10 | 8 | | 56.19 | 43.53 | 43.86 | 21.94 | |
+| A10 | 16 | | 116.49 | 88.56 | 86.64 | 25.62 | |
+| A10 | 32 | | 221.95 | 175.74 | 168.18 | 24.23 | |
+| A10 | 48 | | 333.23 | 264.84 | | 20.52 | |
+
+
+(1) A batch size of 32 or more requires `enable_vae_slicing()` because of https://github.com/pytorch/pytorch/issues/81665. This is needed for PyTorch 1.13.1, and also for PyTorch 2.0 with a batch size of 64 (see the short example below).
+
+For more details about how this benchmark was run, please refer to [this PR](https://github.com/huggingface/diffusers/pull/2303).
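+
+As a reference, enabling sliced VAE decoding for these large-batch runs only needs one extra call on the pipeline. The snippet below is a minimal sketch; the model checkpoint and batch size are only examples.
+
+```Python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
+).to("cuda")
+
+# Decode the VAE one image at a time to work around the batched-attention issue above
+pipe.enable_vae_slicing()
+
+prompt = "a photo of an astronaut riding a horse on mars"
+images = pipe(prompt, num_images_per_prompt=32).images
+```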
\ No newline at end of file
diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index baccdd83f202..27a515645de6 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -99,7 +99,10 @@ def __init__(
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
-        processor = processor if processor is not None else CrossAttnProcessor()
+        # We use AttnProccesor2_0 by default when torch 2.x is used, since it relies on
+        # torch.nn.functional.scaled_dot_product_attention for native flash/memory-efficient attention
+        if processor is None:
+            processor = AttnProccesor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
         self.set_processor(processor)
 
     def set_use_memory_efficient_attention_xformers(
@@ -463,6 +466,50 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
         return hidden_states
 
 
+class AttnProccesor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProccesor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, inner_dim = hidden_states.shape
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.cross_attention_norm:
+            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
 class LoRAXFormersCrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
         super().__init__()