diff --git a/docs/source/en/optimization/attention_backends.md b/docs/source/en/optimization/attention_backends.md
index e603878a6383..8be2c0603009 100644
--- a/docs/source/en/optimization/attention_backends.md
+++ b/docs/source/en/optimization/attention_backends.md
@@ -81,6 +81,46 @@ with attention_backend("_flash_3_hub"):
 > [!TIP]
 > Most attention backends support `torch.compile` without graph breaks and can be used to further speed up inference.
 
+## Checks
+
+The attention dispatcher includes debugging checks that catch common errors before they cause problems.
+
+1. Device checks verify that the query, key, and value tensors live on the same device.
+2. Data type checks confirm the tensors have matching dtypes and use either bfloat16 or float16.
+3. Shape checks validate tensor dimensions and prevent mixing an attention mask with a causal flag.
+
+Enable these checks by setting the `DIFFUSERS_ATTN_CHECKS` environment variable. The checks add overhead to every attention operation, so they're disabled by default.
+
+```bash
+export DIFFUSERS_ATTN_CHECKS=yes
+```
+
+The checks now run before every attention operation.
+
+```py
+import torch
+from diffusers.models.attention_dispatch import attention_backend, dispatch_attention_fn
+
+query = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
+key = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
+value = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
+
+try:
+    with attention_backend("flash"):
+        output = dispatch_attention_fn(query, key, value)
+    print("✓ Flash Attention works with checks enabled")
+except Exception as e:
+    print(f"✗ Flash Attention failed: {e}")
+```
+
+You can also enable the checks by configuring the registry directly.
+
+```py
+from diffusers.models.attention_dispatch import _AttentionBackendRegistry
+
+_AttentionBackendRegistry._checks_enabled = True
+```
+
 ## Available backends
 
 Refer to the table below for a complete list of available attention backends and their variants.