# Flash Attention

Flash Attention is an optimized implementation of attention that addresses two weaknesses of the classical mechanism: high memory requirements and inefficient GPU utilization. This notebook (made by [n.luneva](https://github.com/lwtztea)) explains how Flash Attention works and demonstrates its practical application.

[Paper](https://arxiv.org/pdf/2205.14135) and [HF explanation](https://huggingface.co/docs/text-generation-inference/conceptual/flash_attention).


## 1. Theory

### What is the Attention Mechanism?

The attention mechanism is used in transformer architectures to compute relationships between the elements of a sequence. The classical implementation computes an attention score matrix of size $n \times n$, where $n$ is the sequence length, which leads to the following problems:

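1. **High memory requirements.** Storing the score matrix takes $O(n^2)$ memory.
2. **Inefficient GPU utilization.** The large intermediate matrices do not fit in fast on-chip memory, so they are repeatedly written to and read from slower GPU global memory.
3. **Scaling challenges.** Execution time grows quadratically with sequence length.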
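
To make the quadratic cost concrete, the short sketch below estimates how much memory the $n \times n$ score matrix alone would occupy. The helper name `score_matrix_bytes` and the default of 4 bytes per element (float32) are illustrative assumptions, not part of the original notebook.

```python
def score_matrix_bytes(seq_len: int, batch_size: int = 1, num_heads: int = 1,
                       bytes_per_element: int = 4) -> int:
    """Bytes needed to store the full attention score matrices (float32 by default)."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element


for n in (1024, 4096, 16384):
    mib = score_matrix_bytes(n) / 2**20
    print(f"seq_len={n:>6}: {mib:,.0f} MiB per head")
```

Multiplying the sequence length by 4 multiplies this footprint by 16, and the total grows further with batch size and number of heads.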

### What is Flash Attention?

Flash Attention is an optimized implementation of the attention mechanism that:

- Reduces memory consumption from $O(n^2)$ to $O(n)$.
- Uses streaming computation and tiling for efficient operation on GPUs.
- Ensures numerical stability when computing softmax.
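
The numerical-stability point can be illustrated with a small sketch. It compares a naive softmax with the common max-subtraction variant; the function names `naive_softmax` and `stable_softmax` are hypothetical helpers introduced here for illustration only.

```python
import torch


def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    # exp() overflows to inf for large inputs, so inf / inf produces NaN
    e = torch.exp(x)
    return e / e.sum(dim=-1, keepdim=True)


def stable_softmax(x: torch.Tensor) -> torch.Tensor:
    # Subtracting the row maximum does not change the mathematical result,
    # but keeps every exponent <= 0, so exp() stays within (0, 1]
    m = x.max(dim=-1, keepdim=True).values
    e = torch.exp(x - m)
    return e / e.sum(dim=-1, keepdim=True)


x = torch.tensor([1000.0, 1001.0, 1002.0])
print(naive_softmax(x))   # NaN values, because exp(1000) overflows in float32
print(stable_softmax(x))  # finite probabilities that sum to 1
```

Flash Attention relies on the same idea, tracking the running maximum per row while it processes the scores block by block.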

<img src = https://raw.githubusercontent.com/lwtztea/ml_pic/269a45a/week_6/flash_attention.png width = 2000 >

The inputs are split into blocks because on-chip SRAM is too small to hold the full intermediate results.

<img src = https://raw.githubusercontent.com/lwtztea/ml_pic/957bf7b/week_6/memory_hierarchy.png width = 1000 >

<img src = https://raw.githubusercontent.com/lwtztea/ml_pic/957bf7b/week_6/fa_algorithm.png width = 2000 >

Animations of the computation are available on YouTube:

* Standard Attention — [link](https://www.youtube.com/watch?v=-EF-KIscwJw&list=PLBWdTDczuFNrxnetcH-5CVLRP1u4Ua87m&index=1)
* Flash Attention — [link](https://www.youtube.com/watch?v=cq3jQ-Bbmzs&list=PLBWdTDczuFNrxnetcH-5CVLRP1u4Ua87m&index=5)

For background on block matrix multiplication, see [this MathWorld article](https://mathworld.wolfram.com/BlockMatrix.html).
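
As a minimal sketch of why tiling is possible at all, the example below computes a matrix product by accumulating products of sub-blocks along the inner dimension. The function name `blocked_matmul` and the block size are assumptions chosen for illustration.

```python
import torch


def blocked_matmul(A: torch.Tensor, B: torch.Tensor, block_size: int = 2) -> torch.Tensor:
    """Compute A @ B by summing products of sub-blocks along the inner dimension."""
    n, k = A.shape
    k2, m = B.shape
    assert k == k2, "inner dimensions must match"
    C = torch.zeros(n, m, dtype=A.dtype)
    for start in range(0, k, block_size):
        end = min(start + block_size, k)
        # Each partial product needs only one column block of A and one row block of B
        C += A[:, start:end] @ B[start:end, :]
    return C


torch.manual_seed(0)
A, B = torch.randn(4, 6), torch.randn(6, 3)
print(torch.allclose(blocked_matmul(A, B), A @ B, atol=1e-5))
```

Only one pair of blocks has to be resident at a time, which is what allows the working set to fit into a small, fast memory.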

### Another Secret — Fused Kernel

A fused kernel is an optimization technique in which multiple computational operations are combined into a single GPU kernel. This minimizes the overhead of moving data between memory levels (e.g., between GPU global memory and registers) and reduces the number of kernel launches.

In a traditional approach, each operation (e.g., matrix multiplication, softmax or scaling) is executed separately, leading to the following issues:

* **Frequent memory accesses.** Intermediate results of each operation are written to GPU global memory and then read back for the next operation.
* **Multiple kernel launches.** Each operation requires a separate kernel launch, which increases latency.
* **Inefficient use of GPU resources.** GPU global memory is much slower than on-chip shared memory or registers, so frequent accesses to it reduce performance.

### Fused Kernel in the Context of Flash Attention

Flash Attention actively uses fused kernels to combine multiple stages of the attention mechanism into a single computation.

**Traditional Approach:**
1. Compute attention scores: $S = QK^T$.
2. Store $S$ in global memory.
3. Apply scaling $S' = S / \sqrt{d_k}$.
4. Apply softmax: $A = \text{softmax}(S')$.
5. Store $A$ in global memory.
6. Compute the output: $O = AV$.

Each step requires access to global memory and a separate kernel launch.
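
The traditional steps can be written out in PyTorch as a small sketch. The tensor sizes are arbitrary assumptions; the point is that every intermediate ($S$, the scaled scores, $A$) is a separate $n \times n$ tensor, and in eager mode each line is typically dispatched as its own operation.

```python
import torch

torch.manual_seed(0)
n, d_k = 6, 4
Q, K, V = torch.randn(n, d_k), torch.randn(n, d_k), torch.randn(n, d_k)

S = Q @ K.T                          # steps 1-2: n x n scores, materialized
S_scaled = S / d_k**0.5              # step 3: another n x n tensor
A = torch.softmax(S_scaled, dim=-1)  # steps 4-5: n x n attention weights, materialized
O = A @ V                            # step 6: n x d_k output

print(S.shape, A.shape, O.shape)
```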

**Fused Kernel Approach:**
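1. Compute the attention scores, scaling and softmax in a single kernel, keeping intermediate results in registers and on-chip shared memory instead of writing them to global memory.
2. Compute the output $O = AV$ directly from these on-chip values, without storing $S$ or $A$.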

Fused kernels are usually implemented in low-level languages such as CUDA for maximum optimization. However, modern libraries like PyTorch and TensorFlow provide high-level interfaces for working with them.

<img src = https://raw.githubusercontent.com/lwtztea/ml_pic/957bf7b/week_6/fused_kernel.png width = 1000 >
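
One such high-level interface is `torch.nn.functional.scaled_dot_product_attention`, available in PyTorch 2.0 and later. The sketch below assumes such a version is installed and only checks that the single call matches the step-by-step computation. Which backend actually runs (a Flash Attention kernel or another implementation) depends on the device, dtype and PyTorch build, so no speedup is implied here.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Shapes follow the PyTorch convention [batch, heads, seq_len, head_dim]
Q = torch.randn(2, 4, 16, 8)
K = torch.randn(2, 4, 16, 8)
V = torch.randn(2, 4, 16, 8)

# One call covers scaling, softmax and the product with V
fused_out = F.scaled_dot_product_attention(Q, K, V)

# Step-by-step reference for comparison
scores = Q @ K.transpose(-2, -1) / Q.shape[-1] ** 0.5
reference = torch.softmax(scores, dim=-1) @ V

print(torch.allclose(fused_out, reference, atol=1e-4))
```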

## 2. Implementation

In this section, we look at a simplified implementation of the Flash Attention idea in Python using PyTorch. It illustrates the block-wise computation; it is not a fused GPU kernel.

In [1]:
import torch
import torch.nn as nn

In [2]:
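def flash_attention(Q, K, V, block_size=64):
    """
    Simplified block-wise attention: queries are processed in tiles of
    `block_size` rows, so only a [block_size, seq_len] slice of the score
    matrix exists at any time.

    Args:
        Q (torch.Tensor): Query [batch_size, seq_len, d_k].
        K (torch.Tensor): Key [batch_size, seq_len, d_k].
        V (torch.Tensor): Value [batch_size, seq_len, d_v].
        block_size (int): Number of query rows per tile.
    Returns:
        torch.Tensor: Attention output [batch_size, seq_len, d_v].
    """
    batch_size, seq_len, d_k = Q.shape
    device = Q.device

    O = torch.zeros_like(V)
    # Per-row softmax statistics (row sum and row max). The full algorithm
    # updates them incrementally across key blocks; here each tile sees all
    # keys, so they are only recorded and not used further.
    l_i = torch.zeros(batch_size, seq_len, device=device)
    m_i = torch.full((batch_size, seq_len), float("-inf"), device=device)

    for start in range(0, seq_len, block_size):
        end = min(start + block_size, seq_len)

        # Scaled scores of the current query tile against all keys
        Q_block = Q[:, start:end, :]
        S_block = torch.matmul(Q_block, K.transpose(-2, -1)) / (d_k**0.5)

        # Numerically stable softmax: subtract the row-wise maximum before exp
        M_block = torch.max(S_block, dim=-1, keepdim=True).values
        P_block = torch.exp(S_block - M_block)
        L_block = torch.sum(P_block, dim=-1, keepdim=True)

        # Normalize and write the output rows for this tile
        O[:, start:end, :] = torch.matmul(P_block / L_block, V)

        l_i[:, start:end] = L_block.squeeze(-1)
        m_i[:, start:end] = M_block.squeeze(-1)

    return O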
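
The function above tiles only the queries, so every query tile still sees all keys at once. As a hedged sketch of the tiling described in the theory section, the variant below also iterates over key/value blocks and merges them with an online softmax. The name `flash_attention_tiled` and the helper variables `acc` and `correction` are illustrative assumptions; this is plain PyTorch for clarity, not a fused kernel.

```python
import torch


def flash_attention_tiled(Q, K, V, block_size=4):
    """Sketch: tile over queries and keys/values, merging blocks with an online softmax."""
    batch_size, seq_len, d_k = Q.shape
    scale = d_k ** -0.5
    O = torch.zeros_like(V)

    for q_start in range(0, seq_len, block_size):
        q_end = min(q_start + block_size, seq_len)
        Q_block = Q[:, q_start:q_end, :]
        rows = q_end - q_start

        # Running row maximum, running row sum and unnormalized output for this query tile
        m = torch.full((batch_size, rows, 1), float("-inf"), dtype=Q.dtype, device=Q.device)
        l = torch.zeros(batch_size, rows, 1, dtype=Q.dtype, device=Q.device)
        acc = torch.zeros(batch_size, rows, V.shape[-1], dtype=V.dtype, device=V.device)

        for k_start in range(0, K.shape[1], block_size):
            k_end = min(k_start + block_size, K.shape[1])
            S = torch.matmul(Q_block, K[:, k_start:k_end, :].transpose(-2, -1)) * scale

            m_new = torch.maximum(m, S.max(dim=-1, keepdim=True).values)
            P = torch.exp(S - m_new)
            # Rescale everything accumulated under the old maximum
            correction = torch.exp(m - m_new)

            l = correction * l + P.sum(dim=-1, keepdim=True)
            acc = correction * acc + torch.matmul(P, V[:, k_start:k_end, :])
            m = m_new

        # Normalize once all key/value blocks have been seen
        O[:, q_start:q_end, :] = acc / l

    return O


torch.manual_seed(0)
Q, K, V = (torch.randn(2, 10, 8) for _ in range(3))
reference = torch.softmax(Q @ K.transpose(-2, -1) / 8**0.5, dim=-1) @ V
print(torch.allclose(flash_attention_tiled(Q, K, V), reference, atol=1e-5))
```

With this scheme only a `block_size` by `block_size` tile of scores exists at any moment, and the running maximum keeps the exponentials bounded.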

In [3]:
batch_size, seq_len, d_k = 2, 10, 8
Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_k)

output = flash_attention(Q, K, V)
print("Attention output:")
print(output)

Attention output:
tensor([[[ 5.7712e-01,  7.4143e-02,  9.4435e-01, -6.2278e-01, -7.9697e-01,
          -7.3283e-01,  6.3864e-01,  1.5707e-01],
         [ 3.8677e-01, -1.3144e-01,  8.9284e-01, -7.3510e-01, -3.9344e-01,
          -2.4009e-01,  3.3335e-01, -9.2881e-02],
         [ 1.8090e-01, -2.2231e-01,  6.1934e-01, -7.4626e-01,  8.5221e-02,
          -3.1526e-01,  1.2311e-01,  2.9450e-01],
         [ 2.8106e-02, -1.8882e-01,  3.4479e-01, -5.2488e-01,  9.6422e-02,
          -3.0207e-01,  7.6238e-02,  2.8863e-01],
         [-2.1090e-01, -6.2028e-01, -6.0037e-01, -2.6804e-01,  1.0242e+00,
          -3.3890e-02,  3.3172e-02,  6.0404e-01],
         [ 1.0859e-01, -1.8317e-02,  5.5349e-01, -6.8297e-01, -2.4812e-02,
          -1.1835e-01,  7.0378e-02, -6.5696e-02],
         [ 4.4700e-01, -6.3494e-01,  1.5653e-01, -7.0236e-01, -6.4877e-02,
           4.9781e-02, -6.6919e-02,  4.3952e-02],
         [ 3.0520e-01, -2.5359e-01,  6.6639e-01, -6.1226e-01, -4.4767e-01,
          -1.7331e-01,  1.8060e-
         ...

## 3. Application

Let's apply Flash Attention in the context of an NLP task. We'll build a simple multi-head self-attention layer, the core block of a transformer, that uses the `flash_attention` function defined above for its attention computation.

In [4]:
class TransformerWithFlashAttention(nn.Module):
    """Multi-head self-attention layer that calls flash_attention for every head."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Projections for queries, keys, values and the final output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        # Project and split into heads: [batch_size, num_heads, seq_len, d_k]
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Fold the heads into the batch dimension, as flash_attention expects 3D inputs
        Q = Q.reshape(batch_size * self.num_heads, seq_len, self.d_k)
        K = K.reshape(batch_size * self.num_heads, seq_len, self.d_k)
        V = V.reshape(batch_size * self.num_heads, seq_len, self.d_k)
        output = flash_attention(Q, K, V)

        # Restore the head dimension and concatenate the heads back into d_model
        output = output.view(batch_size, self.num_heads, seq_len, self.d_k)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        return self.W_o(output)
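
The head bookkeeping in `forward` is easy to get wrong, so here is a small standalone sketch of the same reshaping logic. The helpers `split_heads` and `merge_heads` are hypothetical names introduced for illustration; they mirror the `view`/`transpose`/`reshape` calls of the layer and check that merging undoes splitting.

```python
import torch


def split_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    """[batch, seq_len, d_model] -> [batch * num_heads, seq_len, d_k]."""
    batch_size, seq_len, d_model = x.shape
    d_k = d_model // num_heads
    x = x.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
    return x.reshape(batch_size * num_heads, seq_len, d_k)


def merge_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    """[batch * num_heads, seq_len, d_k] -> [batch, seq_len, d_model]."""
    _, seq_len, d_k = x.shape
    x = x.view(-1, num_heads, seq_len, d_k).transpose(1, 2)
    return x.contiguous().view(-1, seq_len, num_heads * d_k)


x = torch.randn(2, 10, 128)
heads = split_heads(x, num_heads=8)
print(heads.shape)
print(torch.equal(merge_heads(heads, num_heads=8), x))
```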

In [5]:
model = TransformerWithFlashAttention(d_model=128, num_heads=8)
x = torch.randn(2, 10, 128)
output = model(x)
print("Model output:")
print(output)

Model output:
tensor([[[-0.0225, -0.0889,  0.0936,  ...,  0.0526, -0.0591,  0.0948],
         [-0.0018, -0.1090,  0.1823,  ...,  0.0269,  0.0462,  0.0804],
         [ 0.0081, -0.0597,  0.1509,  ...,  0.0789, -0.0389,  0.1391],
         ...,
         [ 0.0415, -0.0422,  0.1358,  ...,  0.1021, -0.0331,  0.1239],
         [-0.0240, -0.1373,  0.1252,  ...,  0.0315, -0.0362,  0.0909],
         [-0.0701, -0.0948,  0.1976,  ...,  0.0267,  0.0213,  0.1033]],

        [[ 0.0845, -0.1266, -0.0191,  ..., -0.0731, -0.0965,  0.2096],
         [ 0.0781, -0.1326, -0.1218,  ..., -0.0705, -0.1542,  0.1669],
         [ 0.0564, -0.1580, -0.0757,  ..., -0.0909, -0.1825,  0.1700],
         ...,
         [ 0.0809, -0.0882, -0.0278,  ..., -0.0616, -0.1072,  0.1951],
         [ 0.0396, -0.1612, -0.0817,  ..., -0.1417, -0.0749,  0.2213],
         [ 0.0315, -0.0974, -0.0551,  ..., -0.0613, -0.1030,  0.1355]]],
       grad_fn=<ViewBackward0>)


## 4. Conclusion

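Flash Attention improves the performance of transformer models by tiling the attention computation, keeping the softmax numerically stable, and fusing the steps into a single GPU kernel. This reduces memory use from $O(n^2)$ to $O(n)$ and enables efficient processing of long sequences. The pure-PyTorch code in this notebook illustrates the idea; the actual speedups come from fused GPU kernel implementations.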