In [None]:
!pip install -qU diffusers accelerate transformers huggingface_hub peft torchao

In [None]:
import time
from huggingface_hub import notebook_login
notebook_login()

# Acclerate inference of text-to-image diffusion models

Diffusion models are slower than GANs because of the iterative and sequantial reverse diffusion process. There are several techniques that can address this limitation such as progressive timestep distillation (`LCM LoRA`), model compression (`SSD-1B`), and reusing adjacent features of the denoiser (`DeepCache`).

In this session, we will progressively apply the optimizations found in PyTorch 2 to reduce inference latency.

## Baseline

Disable reduced precision and the `scaled_dot_product_attention` (SDPA) function which is automatically used by Diffusers:

In [None]:
from diffusers import StableDiffusionXLPipeline

# Load the pipeline in full-precision and place its model components on CUDA
pipe = StableDiffusionXLPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0'
).to('cuda')

In [None]:
# Run the attention ops without4 SDPA
pipe.unet.set_default_attn_processor()
pipe.vae.set_default_attn_processor()

In [None]:
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

t_start = time.time()
image = pipe(
    prompt,
    num_inference_steps=30,
).images[0]
print(f'Inference time: {time.time() - t_start} seconds')

## bfloat16

Enable the first optimization, reduced precision or more specifically `bfloat16`:
* using a reduced numerical precision (such as `float16` or `bfloat16`) for inference does not affect the generation quality but significantly improves latency.
* The benefits of using `bfloat16` compared to `float16` are hardware dependent, but modern GPUs tend to favor `bfloat16`.
* `bfloat16` is much more resilient when used with quantization compared with `float16`, but more recent versions of the quantization library (`torchao`) we used do not have numerical issues with `float16`.

In [None]:
from diffusers import StableDiffusionXLPipeline
import torch

pipe = StableDiffusionXLPipeline(
    'stabilityai/stable-diffusion-xl-base-1.0',
    torch_dtype=torch.bfloat16,
).to('cuda')

In [None]:
# Run the attention ops without SDPA
pipe.unet.set_default_attn_processor()
pipe.vae.set_default_attn_processor()

In [None]:
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

t_start = time.time()
image = pipe(
    prompt,
    num_inference_steps=30,
).images[0]
print(f'Inference time: {time.time() - t_start} seconds')

## SDPA

Attention blocks are intensive to run, but with PyTorch's `scaled_dot_product_attention` function, it is a lot more efficient. This function is used by default in Diffusers so we do not need to make any changes to the code.

In [None]:
from diffusers import StableDiffusionXLPipeline
import torch

pipe = StableDiffusionXLPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    torch_dtype=torch.bfloat16,
).to('cuda')

In [7]:
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

t_start = time.time()
image = pipe(
    prompt,
    num_inference_steps=30,
).images[0]
print(f'Inference time: {time.time() - t_start} seconds')

  0%|          | 0/30 [00:00<?, ?it/s]

Inference time: 179.09928822517395 seconds


## torch.compile

Configure a few compiler flags

In [8]:
from diffusers import StableDiffusionXLPipeline
import torch

torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

Change the UNet and VAE memory layout to "channels_last" when compiling them to ensure maximum speed.

In [9]:
pipe.unet.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (1): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (c

Now compile and perform inference:

In [10]:
# Compile UNet and VAE
pipe.unet = torch.compile(
    pipe.unet,
    mode='max-autotune',
    fullgraph=True,
)
pipe.vae.decode = torch.compile(
    pipe.vae.decode,
    mode='max-autotune',
    fullgraph=True,
)

In [None]:
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

t_start = time.time()
image = pipe(
    prompt,
    num_inference_steps=30,
).images[0]
print(f'Inference time: {time.time() - t_start} seconds')

`torch.compile` offers different backends and modes. For maximum inference speed, we can use "max-autotune" for the inductor backend.

"max-autotune" uses CUDA graphs and optimizes the compilation graph specifically for latency. CUDA graphs greatly reduces the overhead of launching GPU oeprations by using a mechanism to launch multiple GPU operations through a single CPU operation.

### Prevent graph breaks

Specifying `fullgraph=True` ensures there are no graph breaks in the underlying model to take full advantage of `torch.compile` without any performance degradation.

### Remove GPU sync after compilation

During the iterative reverse diffusion process, the `step()` function is called on the scheduler each time after the denoiser predicts the less noisy latent embedings.

Inside `step()`, the `sigmas` variable is indexed which when placed on the GPU, causes a communication sync between the CPU and GPU. This introduces latency and it becomes more evident when the denoiser has already been compiled. If the `sigmas` array always stays on the CPU, the CPU and GPU sync does not occur and we do not get any latency.

## Combine the attention block's projection matrices

The UNet and VAE in SDXL use Transformer-like blocks which consists of attention blocks and feed-forward blocks.

In an attention block, the input is projected into three sub-spaces using three different projection matrices - Q, K, and V. These projections are performed separately on the input. However, we can horizontally combine the projection matrices into a single matrix and perform the projection in one step. This increases the sizes the matrix multiplications of the input projections and improves the impact of quantization:

In [None]:
pipe.fuse_qkv_projections()

## Dynamic quantization

We can also use the ultra-lightweight PyTorch quantization library, `torchao`, to apply **dynamic int8 quantization** to the UNet and VAE.

Quantization adds additional conversion overhead to the model that is hopefully made up for by faster `matmul`'s (dynamic quantization). If the `matmul`'s are too small, these techniques may degrade performance.

First, we need to configure all the compiler tags:

In [12]:
from diffusers import StableDiffusionXLPipeline
import torch

torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.epilogue_fusion = False
torch._inductor.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

Certain linear layers in the UNet and VAE do not benefit from dynamic int8 quantization. We can filter out those layers with the `dynamic_quant_filter_fn` below:

In [13]:
def dynamic_quant_filter_fn(mod, *args):
    return (
        isinstance(mod, torch.nn.Linear)
        and mod.in_features > 16
        and (mod.in_features, mod.out_features)
        not in [
            (1280, 640),
            (1920, 1280),
            (1920, 640),
            (2048, 1280),
            (2048, 2560),
            (2560, 1280),
            (256, 128),
            (2816, 1280),
            (320, 640),
            (512, 1536),
            (512, 256),
            (512, 512),
            (640, 1280),
            (640, 1920),
            (640, 320),
            (640, 5120),
            (640, 640),
            (960, 320),
            (960, 640),
        ]
    )


def conv_filter_fn(mod, *args):
    return (
        isinstance(mod, torch.nn.Conv2d)
        and mod.kernel_size == (1, 1)
        and 128 in [mod.in_channels, mod.out_channels]
    )

Finally, we can apply all the optimizations:

In [None]:
# SDPA + bfloat16
pipe = StableDiffusionXLPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    torch_dtype=torch.bfloat16,
).to('cuda')

# Combine attention projection matrices
pipe.fuse_qkv_projections()

# Change the memory layout
pipe.unet.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

Since dynamic quantization is only limited to the linear layers, we convert the appropriate pointwise convolution layers into linear layers to maximize its benefit:

In [None]:
from torchao import swap_conv2d_1x1_to_linear

swap_conv2d_1x1_to_linear(pipe.unet, conv_filter_fn)
swap_conv2d_1x1_to_linear(pipe.vae, conv_filter_fn)

Apply dynamic quantization:

In [None]:
from torchao import apply_dynamic_quant

apply_dynamic_quant(pipe.unet, dynamic_quant_filter_fn)
apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)

Finally, we can compile and perform inference:

In [None]:
pipe.unet = torch.compile(
    pipe.unet,
    mode='max-autotune',
    fullgraph=True,
)
pipe.vae.decode = torch.compile(
    pipe.vae.decode,
    mode='max-autotune',
    fullgraph=True,
)

In [None]:
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

t_start = time.time()
image = pipe(
    prompt,
    num_inference_steps=30,
).images[0]
print(f'Inference time: {time.time() - t_start} seconds')