```
ImportError: PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1.
```

- https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
torch.__version__

'2.3.1+cu121'

In [3]:
Q, K, V = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)

In [4]:
X = F.scaled_dot_product_attention(Q, K, V)
X

tensor([[[-1.3321, -0.3489,  0.3015, -0.3912,  0.9867,  0.3137, -0.0691,
          -1.2593],
         [-1.0882,  0.2506,  0.6491,  0.1360,  0.5238, -0.2448, -0.0820,
          -0.6171],
         [-1.0012,  0.3990,  0.6441, -0.0277,  0.5325, -0.2564, -0.0607,
          -0.6404]],

        [[ 0.6091,  0.0708,  0.6188,  0.3252, -0.1598,  0.4197, -0.2335,
           0.0630],
         [ 0.5285,  0.3890, -0.2649,  0.3706, -0.3839,  0.1963, -0.6242,
           0.2312],
         [ 0.4048,  0.0762,  0.3777,  0.4689, -0.2978,  0.2754, -0.6429,
           0.1037]]], device='cuda:0')

In [5]:
X.shape

torch.Size([2, 3, 8])

$$
\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

In [6]:
F.softmax(torch.bmm(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(8)), dim=-1).shape

torch.Size([2, 3, 3])

In [7]:
torch.bmm(F.softmax(torch.bmm(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(8)), dim=-1), V)

tensor([[[-1.3321, -0.3489,  0.3015, -0.3912,  0.9867,  0.3137, -0.0691,
          -1.2593],
         [-1.0882,  0.2506,  0.6491,  0.1360,  0.5238, -0.2448, -0.0820,
          -0.6171],
         [-1.0012,  0.3990,  0.6441, -0.0277,  0.5325, -0.2564, -0.0607,
          -0.6404]],

        [[ 0.6091,  0.0708,  0.6188,  0.3252, -0.1598,  0.4197, -0.2335,
           0.0630],
         [ 0.5285,  0.3890, -0.2649,  0.3706, -0.3839,  0.1963, -0.6242,
           0.2312],
         [ 0.4048,  0.0762,  0.3777,  0.4689, -0.2978,  0.2754, -0.6429,
           0.1037]]], device='cuda:0')

## SDPA

- The default implementation runs in 26186.948 microseconds
- The math implementation runs in 50155.869 microseconds
- The flash attention implementation runs in 26189.985 microseconds
- The memory efficient implementation runs in 48395.111 microseconds
- PyTorch’s `torch.nn.functional.scaled_dot_product_attention` (SDPA) can also call `FlashAttention` and `memory-efficient attention kernels` under the hood. SDPA support is currently being added natively in Transformers and is used by default for torch>=2.1.1 when an implementation is available. You may also set attn_implementation="sdpa" in from_pretrained() to explicitly request SDPA to be used.

In [2]:
import torch.utils.benchmark as benchmark
def benchmark_sdpa(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

In [3]:
# Lets define the hyper-parameters of our input
bs = 32
seq_len = 2*1024
n_heads = 32
embed_dimen = 224

dtype = torch.float16


In [4]:
Q = torch.rand(bs, n_heads, seq_len, embed_dimen, device=device, dtype=dtype)
K = torch.rand(bs, n_heads, seq_len, embed_dimen, device=device, dtype=dtype)
V = torch.rand(bs, n_heads, seq_len, embed_dimen, device=device, dtype=dtype)

In [5]:
f"The default implementation runs in {benchmark_sdpa(F.scaled_dot_product_attention, Q, K, V):.3f} microseconds"

'The default implementation runs in 26101.244 microseconds'

In [6]:
from torch.nn.attention import SDPBackend, sdpa_kernel

with sdpa_kernel(SDPBackend.MATH):
    math_time=benchmark_sdpa(F.scaled_dot_product_attention, Q, K, V)
    print(f"The math implementation runs in {math_time:.3f} microseconds")

OutOfMemoryError: CUDA out of memory. Tried to allocate 8.00 GiB. GPU 

In [13]:
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        flash_time=benchmark_sdpa(F.scaled_dot_product_attention, Q, K, V)
        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

The flash attention implementation runs in 26189.985 microseconds


In [14]:
Q = torch.rand(bs, n_heads, seq_len, embed_dimen, device='cuda:1', dtype=dtype)
K = torch.rand(bs, n_heads, seq_len, embed_dimen, device='cuda:1', dtype=dtype)
V = torch.rand(bs, n_heads, seq_len, embed_dimen, device='cuda:1', dtype=dtype)

In [15]:
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    try:
        efficient_time=benchmark_sdpa(F.scaled_dot_product_attention, Q, K, V)
        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")

The memory efficient implementation runs in 48395.111 microseconds


## Causal Self Attention

- https://github.com/karpathy/nanoGPT/blob/master/model.py

In [50]:
embed_dim = 12
n_heads = 2
assert embed_dim % n_heads == 0

In [51]:
head_dim = embed_dim // (n_heads * 3)
head_dim

2

In [33]:
# W_q, W_k, W_v
c_attn = nn.Linear(embed_dim, 3 * embed_dim, bias=False)

In [40]:
X = torch.randn(2, 5, embed_dim)

In [41]:
QKV = c_attn(X)
QKV.shape

torch.Size([2, 5, 36])

In [42]:
Q, K, V = QKV.chunk(3, -1)

In [43]:
torch.allclose(QKV[:, :, :12], Q)

True

In [44]:
Q.shape

torch.Size([2, 5, 12])

In [52]:
Q = Q.view(1, -1, n_heads, embed_dim//n_heads).transpose(1, 2)
K = K.view(1, -1, n_heads, embed_dim//n_heads).transpose(1, 2)
V = V.view(1, -1, n_heads, embed_dim//n_heads).transpose(1, 2)