# Implement Attention from Scratch

# Problem Statement

Implement a **Scaled Dot-Product Attention** mechanism from scratch using PyTorch. Your mission is to replicate, manually, what PyTorch's built-in `scaled_dot_product_attention` does.

This core component is essential in Transformer architectures and helps models focus on relevant parts of a sequence. You'll test your implementation against PyTorch's native one to ensure you nailed it.

### Requirements
1. Define the Function:
   - Create a function `scaled_dot_product_attention(q, k, v, mask=None)` that:
     - Computes attention scores via the dot product of query and key vectors.
     - Scales the scores using the square root of the key dimension.
     - Applies an optional mask to the scores.
     - Applies softmax to convert scores into attention weights.
     - Uses these weights to compute a weighted sum of values (V).
2. Test Your Work:
   - Use sample tensors for query (Q), key (K), and value (V).
   - Compare the result of your custom implementation with PyTorch's `F.scaled_dot_product_attention` using an assert to check numerical accuracy.

### Constraints
- ❌ Do NOT use `F.scaled_dot_product_attention` inside your custom function; that defeats the whole point.
- ✅ Your implementation must handle batch dimensions correctly.
- ✅ Support optional masking for future tokens or padding.
- ✅ Use only PyTorch ops: no cheating with external attention libs.

💡 Hint:
- Use `torch.matmul()` to compute dot products and `F.softmax()` for the final attention weights.
- The mask (if used) should be applied **before** the softmax using `masked_fill`.


### Rephrase

Create a function `scaled_dot_product_attention(q, k, v, mask=None)` that manually replicates PyTorch's built-in attention.

- **Compute attention scores**: Dot product between queries and keys using `torch.matmul()`
- **Scale scores**: Divide by square root of key dimension (`sqrt(d_k)`)
- **Apply mask** (optional): Use `masked_fill()` for future tokens or padding
- **Compute attention weights**: Apply `F.softmax()` to scaled scores.
- **Compute output**: Weighted sum of values using attention weights

**Validation**: Compare your implementation with `F.scaled_dot_product_attention()` using numerical assertion.

**Constraints**:
- Use only basic PyTorch operations (no high-level attention functions)
- Handle batch dimensions correctly
- Support optional masking

# Why it matters: understanding how the Transformer "brain" works

The goal is to replicate the logic of `F.scaled_dot_product_attention`, breaking it down into key steps:

- **Dot product:** `Q · Kᵀ` → "How similar is each query to each key?"
- **Scaling:** divide by `√dₖ` → keeps softmax from saturating at large `dₖ` (the variance of the dot products grows with `dₖ`; dividing brings it back to about 1).
- **Masking (optional):** `masked_fill(mask, -inf)` → blocks attention to forbidden positions (padding, future tokens).
- **Softmax:** converts scores into probabilities (weights).
- **Weighted sum:** `weights · V` → aggregates information from the values based on relevance.

Important:
- Support for batch dimensions `(Batch_size, Num_Heads, Sequence_Length, Dimension_per_head)`.
- Numerical equivalence with `F.scaled_dot_product_attention` (checked via `torch.allclose(..., atol=1e-6)`).
- No internal calls to `F.scaled_dot_product_attention`: only basic torch operations.

# Theory

Attention is a mechanism that lets neural networks focus on specific parts of an input sequence.

A fundamental type is Scaled Dot-Product Attention (used in Transformer). It has three inputs:

- Query (Q): The current token trying to gather information.
- Key (K): A representation of each token in the sequence that's available to be attended to.
- Value (V): What each token provides if selected by the attention mechanism.

### Scaled Dot-Product Attention Concept

| Component | Mathematics | Analogy | Dimensionality | Why is it needed? |
|-----------|------------|----------|-------------|--------------|
| **Query (Q)** | Query vector | "What am I looking for?" | `[batch, seq_len, d_k]` | Represents the interest of the current position |
| **Key (K)** | Key vector | "What can I offer?" | `[batch, seq_len, d_k]` | Characterizes how relevant a token is to others |
| **Value (V)** | Value vector | "What information do I carry?" | `[batch, seq_len, d_v]` | Contains the actual information for aggregation |
| **`Q·Kᵀ`** | `matmul(Q, K.transpose(-2,-1))` | "How well do the questions and answers match?" | `[batch, seq_len, seq_len]` | Calculates the pairwise similarity of all tokens |
| **Scaling** | `÷ √d_k` | "Normalizing the scores" | - | Keeps softmax out of saturated regions and stabilizes gradients at high dimensions |
| **Masking** | `masked_fill(mask, -inf)` | "Ignore forbidden positions" | `[batch, seq_len, seq_len]` | Prevents attention to padding/future tokens |
| **Softmax** | `exp(x) / ∑exp(x)` | "Convert to probabilities" | `[batch, seq_len, seq_len]` | Converts scores to weights (sum=1) |
| **Output** | `weights @ V` | "Weighted summation of information" | `[batch, seq_len, d_v]` | Aggregates values by relevance |

### Scaled Dot-Product Attention calculation step-by-step:
1. We measure how relevant each key `K` is to our query `Q` using a dot product: `Q·Kᵀ`.
2. To keep the values stable for large embeddings, we divide by `√d_k`, where `d_k` is the dimensionality of the key vectors: `Scaled = Q·Kᵀ / √d_k`.
 When `d_k` is large, the dot products can grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. If the components of `Q` and `K` have zero mean and unit variance, the variance of a dot product grows linearly with `d_k`, so its standard deviation grows like `√d_k`. Dividing by `√d_k` (rather than by `d_k`) brings the variance back to about 1.
3. Convert the scores into a probability distribution to see how much attention should be given to each element: `Weights = softmax(Scaled)`.
4. Multiply each value in `V` by its attention weight and sum to get the final output: `Attention(Q, K, V) = Weights · V`.
This yields a context vector that highlights the most relevant information from `V` for the query `Q`.

**In short**: attention computes a weighted sum of input elements (values), where the weights are determined by a compatibility function between a query and the corresponding keys: `Attention(Q, K, V) = softmax(Q·Kᵀ/√d_k)·V`
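
The scaling argument in step 2 can be checked empirically. The sketch below assumes that query and key components are independent with zero mean and unit variance (an idealization that real activations only approximate); the sample count and the list of dimensions are arbitrary choices for illustration.

```python
import math
import torch

torch.manual_seed(0)
n_samples = 10_000  # arbitrary sample count for the estimate

raw_var, scaled_var = {}, {}
for d_k in (16, 64, 256, 1024):
    # Assumption: components are i.i.d. standard normal.
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)            # one dot product per sample
    raw_var[d_k] = dots.var().item()      # grows roughly like d_k
    scaled_var[d_k] = (dots / math.sqrt(d_k)).var().item()  # stays near 1
    print(f"d_k={d_k:5d}  var(q.k)={raw_var[d_k]:9.2f}  "
          f"var(q.k/sqrt(d_k))={scaled_var[d_k]:.3f}")
```

Under these assumptions the unscaled variance tracks `d_k`, while the scaled variance stays close to 1 for every dimension.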

**Simplifications:**

> Imagine you're at a large party trying to focus on a specific conversation. You're asking yourself about each person: "How relevant is what this person is saying to what I want to know?" (computing attention scores). Then you focus more on people providing useful information (applying the attention weights) while still maintaining some awareness of everyone else. Your brain combines all this information, giving more weight to important sources (weighted sum of values).

Or

> A simple explanation: attention is just a dictionary with approximation. In a usual dictionary we have a pair of key-value and we pass a query to get a result. We either get the value of the key or nothing. In attention we get the answer even if we can't find the exact key.
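
The dictionary analogy can be made concrete with a toy "soft lookup". Everything in this sketch (the keys, the values, and the helper name `soft_lookup`) is made up for illustration; it is not part of the solution further down.

```python
import math
import torch
import torch.nn.functional as F

# Hypothetical toy "dictionary": three keys with one value vector each.
keys = torch.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
values = torch.tensor([[10.0], [20.0], [30.0]])

def soft_lookup(query, keys, values):
    """Return a similarity-weighted blend of values instead of an exact match."""
    scores = query @ keys.T / math.sqrt(keys.size(-1))
    weights = F.softmax(scores, dim=-1)
    return weights @ values, weights

# A query that matches no key exactly but points mostly toward the first key.
query = torch.tensor([[4.0, 1.0]])
result, weights = soft_lookup(query, keys, values)
print(weights)  # largest weight on the first key
print(result)   # a blend dominated by the first value
```

A regular dictionary would return nothing for this query; the soft lookup returns a blend in which the closest key contributes the most.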

## The essence of the dimensions d_k and d_v
`d_k` (**dimension of keys/queries**) and `d_v` (**dimension of values**) are not fixed constants: they are design parameters derived from the architecture of the attention mechanism and from its quality/efficiency trade-offs.

- Self-Attention: Q, K, V from one source → usually `d_k = d_v`
- Cross-Attention: Q from one source, K, V from another → `d_k` and `d_v` can differ
- Multi-Head Attention: embedding is divided into heads → `d_k = d_v = embedding_dim / num_heads`
- MultiQuery Attention: shared K, V for all heads → memory savings

| Attention | Essence | Analogy | Pattern | Using | Memory |
|-----------|---------|---------|---------------------|------|-----|
| **Self-Attention** | Tokens of the same sequence look at each other | Group discussion | `Q,K,V = f(same_sequence)` | Text comprehension (BERT) | Average |
| **Cross-Attention** | Queries from sequence A, Keys/Values from sequence B | Student asks textbook | `Q = f(A)`, `K,V = f(B)` | Machine translation, chatbots | Average |
| **Multi-Head** | Many "experts" for different aspects | Team of experts | `split → Attention_i → concat` | Transformers (GPT, BERT) | High |
| **MultiQuery** | Shared K,V for all heads to save memory | One reference book for whole class | `K,V` shared across heads | Fast-inference models (e.g. PaLM) | Low |

For more on the types of attention, see the appendix at the end of this file.
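
To make the MultiQuery row concrete, here is a minimal sketch in which all query heads share a single key/value head. The shapes and the helper name `multi_query_attention` are illustrative assumptions, not an API from PyTorch or from any specific model.

```python
import math
import torch
import torch.nn.functional as F

def multi_query_attention(q, k, v):
    """q: [B, H, L, d]; k, v: [B, 1, L, d] (one K/V head shared by all H query heads)."""
    # matmul broadcasts the size-1 head dimension of k and v across H.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # [B, H, L, L]
    weights = F.softmax(scores, dim=-1)
    return weights @ v                                         # [B, H, L, d]

torch.manual_seed(0)
B, H, L, d = 2, 4, 5, 8
q = torch.randn(B, H, L, d)
k_shared = torch.randn(B, 1, L, d)
v_shared = torch.randn(B, 1, L, d)

out = multi_query_attention(q, k_shared, v_shared)
print(out.shape)  # torch.Size([2, 4, 5, 8])

# K and V hold H times fewer elements than they would with one head each.
print(k_shared.numel(), "vs", k_shared.expand(B, H, L, d).numel())
```

The memory saving comes only from storing one K/V head; the attention math per query head is unchanged.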

### Multi-Head Attention
`d_k = d_v = embedding_dim // num_heads`, where:
- `embedding_dim`: total embedding dimension (512, 768, 1024)
- `num_heads`: number of attention heads (8, 12, 16)


### Optimization tasks
- Model quality: large `d_k/d_v` → higher capacity
- Efficiency: small `d_k/d_v` → faster computation
- Memory: MultiQuery → less memory for `K`, `V`

## Code Solution


### Simple version without `causal_mask`

In [None]:
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(q, k, v, mask=None):
  d_k = k.size(-1)

  # Step 1: Calculate Similarity
  scores = torch.matmul(q, k.transpose(-2, -1))

  # Step 2: Scaling
  scores = scores / math.sqrt(d_k)  # d_k is the last dimension of the keys

  # Step 3: Masking (if provided)
  if mask is not None:
    # mask: True = block
    scores = scores.masked_fill(mask, float('-inf'))

  # Step 4: Softmax -> attention weights
  attn_weights = F.softmax(scores, dim=-1)

  # Step 5: Weighted sum of values
  output = torch.matmul(attn_weights, v)

  return output

# def create_causal_mask(seq_len, batch_size=1):
#   return torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

Minimal, correct test:

- identical inputs
- `dropout_p=0.0`
- `is_causal=False`
- `eval()` mode is not needed (dropout is disabled)

In [None]:
import torch
import torch.nn.functional as F
import math


# =========================
# TEST
# =========================
torch.manual_seed(42)

B = 2   # batch size
L = 5   # sequence length
D = 8   # embedding dim

q = torch.randn(B, L, D)
k = torch.randn(B, L, D)
v = torch.randn(B, L, D)

out_custom = scaled_dot_product_attention(q, k, v)

out_torch = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False
)

# Compare the two outputs numerically
torch.testing.assert_close(
    out_custom,
    out_torch,
    rtol=1e-5,
    atol=1e-6
)

print("✅ Test passed: custom implementation matches PyTorch.")


✅ Test passed: custom implementation matches PyTorch.


### Version with Causal mask
A causal mask removes connections to future positions from the probability distribution.

Why is masking applied before softmax:
- `exp(-inf) = 0`
- the weights of future tokens become `0`
- row sum = `1`
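
A quick way to see why the mask goes before softmax is to compare it with masking afterwards. The scores and the mask below are arbitrary toy values chosen for illustration.

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([[2.0, 1.0, 0.5]])
blocked = torch.tensor([[False, False, True]])  # True = position must be ignored

# Masking before softmax: the blocked score becomes -inf, and exp(-inf) = 0.
before = F.softmax(scores.masked_fill(blocked, float("-inf")), dim=-1)

# Masking after softmax: the blocked weight is zeroed, but the rest is not renormalized.
after = F.softmax(scores, dim=-1).masked_fill(blocked, 0.0)

print(before, before.sum().item())  # blocked weight is 0, row sums to 1
print(after, after.sum().item())    # blocked weight is 0, row sums to less than 1
```

Masking first lets softmax redistribute the probability mass over the allowed positions; masking afterwards would leave a distribution that no longer sums to 1.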

**Causal mask**
At position `t` the model is not allowed to look into the future `(t+1, t+2, ...)`:
```css
Position 0 -> [0]
Position 1 -> [0,1]
Position 2 -> [0,1,2]
Position 3 -> [0,1,2,3]
Position 4 -> [0,1,2,3,4]
```

All elements above the main diagonal must be masked:
- `1`: attention disabled
- `0`: attention allowed
```css
[[0, 1, 1, 1, 1],
 [0, 0, 1, 1, 1],
 [0, 0, 0, 1, 1],
 [0, 0, 0, 0, 1],
 [0, 0, 0, 0, 0]]
```

[torch.triu](https://docs.pytorch.org/docs/stable/generated/torch.triu.html):
- `torch.triu(..., diagonal=1)`: keeps everything above the main diagonal
- `dtype=torch.bool`: `masked_fill` expects a boolean mask

In [None]:
L = q.size(-2)

causal_mask = torch.triu(
    torch.ones(L, L, dtype=torch.bool),
    diagonal=1
)
causal_mask

tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False, False]])

In [None]:
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention_custom(
  q, k, v,
  attn_mask=None,
  is_causal=False
):

  d_k = k.size(-1)

  # Step 1: Calculate Similarity
  scores = torch.matmul(q, k.transpose(-2, -1))

  # Step 2: Scaling
  scores = scores / math.sqrt(d_k)  # d_k is the last dimension of the keys

  # Step 3a: causal mask (if requested)
  if is_causal:
    L = q.size(-2)
    causal_mask = torch.triu(
      torch.ones(L, L, device=q.device, dtype=torch.bool),
      diagonal=1 # True above the diagonal; the diagonal itself is NOT touched
    )
    # If causal_mask[t, j] == True: scores[t, j] = -inf -> softmax weight = 0
    scores = scores.masked_fill(causal_mask, float('-inf'))

  # Step 3b: attention mask (padding / custom)
  if attn_mask is not None:
    scores = scores.masked_fill(attn_mask, float('-inf'))

  # Step 4: Softmax -> attention weights
  attn_weights = F.softmax(scores, dim=-1)

  # Step 5: Weighted sum of values
  output = torch.matmul(attn_weights, v)

  return output

Test: causal mask, comparison with PyTorch

In [None]:
torch.manual_seed(0)

B, L, D = 2, 6, 8

q = torch.randn(B, L, D)
k = torch.randn(B, L, D)
v = torch.randn(B, L, D)

out_custom = scaled_dot_product_attention_custom(
    q, k, v,
    is_causal=True
)

out_torch = F.scaled_dot_product_attention(
    q, k, v,
    is_causal=True,
    dropout_p=0.0
)

torch.testing.assert_close(
    out_custom,
    out_torch,
    rtol=1e-5,
    atol=1e-6
)

print("✅ Causal attention test passed.")


✅ Causal attention test passed.


#### Padding mask - Comparison with PyTorch

Since the goal is an implementation that is fully compatible with PyTorch, it is worth noting how PyTorch itself behaves:
1. `is_causal=True`
   - PyTorch **generates the causal mask itself**
   - `attn_mask` must be `None`
   - Passing both raises an error
2. `attn_mask`
   - Either a boolean mask (in PyTorch, `True` = the position takes part in attention)
   - or an additive float mask (`0` / `-inf`) that is added to the scores
   - used for **padding/arbitrary** masking

`F.scaled_dot_product_attention` enforces a hard contract:
`is_causal=True ⟹ attn_mask MUST be None`. A likely reason is performance: `is_causal=True` lets PyTorch dispatch to specialized fused kernels that build the causal pattern themselves:
 - FlashAttention
 - Memory-efficient attention

Combining a causal pattern with an arbitrary mask is therefore left to the caller.

#### What is a padding mask (in practice)
A padding mask makes attention ignore padding tokens:
- every `PAD` column of the scores is set to `-inf`
- softmax then assigns those positions zero weight

Note: the padding mask always masks keys, not queries.

In [None]:
import torch

# 1 - real token, 0 - padding
padding_mask = torch.tensor([
    [1, 1, 1, 0, 0],
    [1, 1, 1, 1, 0],
])
# scores[b, i, j] = Q[b, i] · K[b, j] -> [B, L_query, L_key]
attn_mask = padding_mask[:, None, :] == 0
padding_mask.shape, attn_mask.shape, padding_mask, attn_mask

(torch.Size([2, 5]),
 torch.Size([2, 1, 5]),
 tensor([[1, 1, 1, 0, 0],
         [1, 1, 1, 1, 0]]),
 tensor([[[False, False, False,  True,  True]],
 
         [[False, False, False, False,  True]]]))

In [None]:
scores = torch.zeros(2, 5, 5)
scores.masked_fill(attn_mask, float('-inf'))

tensor([[[0., 0., 0., -inf, -inf],
         [0., 0., 0., -inf, -inf],
         [0., 0., 0., -inf, -inf],
         [0., 0., 0., -inf, -inf],
         [0., 0., 0., -inf, -inf]],

        [[0., 0., 0., 0., -inf],
         [0., 0., 0., 0., -inf],
         [0., 0., 0., 0., -inf],
         [0., 0., 0., 0., -inf],
         [0., 0., 0., 0., -inf]]])

In [None]:
torch.manual_seed(0)

B, L, D = 2, 6, 8

q = torch.randn(B, L, D)
k = torch.randn(B, L, D)
v = torch.randn(B, L, D)

padding_mask = torch.tensor([
    [1, 1, 1, 0, 0, 0],
    [1, 1, 1, 1, 0, 0],
], dtype=torch.bool)

attn_mask = ~padding_mask[:, None, :]

out_custom = scaled_dot_product_attention_custom(
    q, k, v,
    attn_mask=attn_mask,
    is_causal=False
)

out_torch = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=attn_mask,
    dropout_p=0.0,
    is_causal=False
)

torch.testing.assert_close(
    out_custom,
    out_torch,
    rtol=1e-5,
    atol=1e-6
)

print("✅ Attention test with padding mask passed.")


AssertionError: Tensor-likes are not close!

Mismatched elements: 96 / 96 (100.0%)
Greatest absolute difference: 2.1699323654174805 at index (1, 2, 4) (up to 1e-06 allowed)
Greatest relative difference: 105.48564910888672 at index (0, 1, 4) (up to 1e-05 allowed)

Wow, a 100% mismatch!
Most likely there is a semantic difference between my mask and PyTorch's.

In my implementation:
- `scores = scores.masked_fill(attn_mask, -inf)` -> `True` means attention is DISABLED

In PyTorch (`F.scaled_dot_product_attention`):
- `attn_mask == True` → ALLOWED
- `attn_mask == False` → DISABLED

That explains the 100% mismatch:
 - I block certain positions
 - PyTorch blocks exactly the complementary positions


Some plausible reasons for PyTorch's convention (my interpretation, not an official statement):
- `attn_mask` acts as an attention bias: allowed positions get `0`, blocked positions get `-inf`
- `True = allowed` reads as "keep this position" and may be convenient for fused kernels such as FlashAttention
- `masked_fill` is simply a low-level API with the opposite logic ("fill where `True`")

Let's invert `attn_mask` inside the implementation so that it accepts masks with the same semantics as PyTorch.

In [None]:
attn_mask_test = torch.tensor([True, True, False])
~attn_mask_test

tensor([False, False,  True])

In [None]:
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention_custom(
  q, k, v,
  attn_mask=None,
  is_causal=False,
  dropout_p=0.0
):

  # API contract: PyTorch does not allow is_causal=True together with attn_mask
  if is_causal and attn_mask is not None:
    raise RuntimeError(
      "Explicit attn_mask should not be set when is_causal=True"
    )

  d_k = k.size(-1)

  # Step 1: Calculate Similarity
  scores = torch.matmul(q, k.transpose(-2, -1))

  # Step 2: Scaling
  scores = scores / math.sqrt(d_k)  # d_k is the last dimension of the keys

  # Step 3a: causal mask (if requested)
  if is_causal:
    L = q.size(-2)
    causal_mask = torch.triu(
      torch.ones(L, L, device=q.device, dtype=torch.bool),
      diagonal=1 # True above the diagonal; the diagonal itself is NOT touched
    )
    # If causal_mask[t, j] == True: scores[t, j] = -inf -> softmax weight = 0
    scores = scores.masked_fill(causal_mask, float('-inf'))

  # Step 3b: attention mask (padding / custom), PyTorch semantics
  if attn_mask is not None:
    if attn_mask.dtype == torch.bool:
      # Boolean mask: True = attend, False = block
      scores = scores.masked_fill(~attn_mask, float('-inf'))
    else:
      # Float mask: added to the scores (0 = keep, -inf = block)
      scores = scores + attn_mask

  # Step 4: Softmax -> attention weights
  attn_weights = F.softmax(scores, dim=-1)

  # Step 5: Weighted sum of values
  output = torch.matmul(attn_weights, v)

  return output


Test padding

In [None]:

torch.manual_seed(0)

B, L, D = 2, 6, 8

q = torch.randn(B, L, D)
k = torch.randn(B, L, D)
v = torch.randn(B, L, D)

padding_mask = torch.tensor([
    [1, 1, 1, 0, 0, 0],
    [1, 1, 1, 1, 0, 0],
], dtype=torch.bool)

attn_mask = padding_mask[:, None, :]  # PyTorch semantics: True = attend (real tokens)

out_custom = scaled_dot_product_attention_custom(
    q, k, v,
    attn_mask=attn_mask,
    is_causal=False
)

out_torch = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=attn_mask,
    dropout_p=0.0,
    is_causal=False
)

torch.testing.assert_close(
    out_custom,
    out_torch,
    rtol=1e-5,
    atol=1e-6
)

print("✅ Attention test with padding mask passed.")

✅ Attention test with padding mask passed.
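
PyTorch's `attn_mask` can also be a float tensor that is added to the scores. The sketch below converts a boolean padding mask (PyTorch convention, `True` = attend) into the additive form and checks that both give the same result; the variable names `keep`, `bool_mask` and `float_mask` are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, L, D = 2, 6, 8
q, k, v = (torch.randn(B, L, D) for _ in range(3))

# True = real token, False = padding (PyTorch convention: True takes part in attention).
keep = torch.tensor([
    [1, 1, 1, 0, 0, 0],
    [1, 1, 1, 1, 0, 0],
], dtype=torch.bool)
bool_mask = keep[:, None, :]  # [B, 1, L_key], broadcasts over query positions

# Additive form: 0 where attention is allowed, -inf where it is blocked.
float_mask = torch.zeros(B, 1, L).masked_fill(~bool_mask, float("-inf"))

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)

torch.testing.assert_close(out_bool, out_float, rtol=1e-5, atol=1e-5)
print("Boolean and additive masks give the same output.")
```

The additive form is the more general one: finite values can bias attention toward or away from positions instead of blocking them outright.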


Test causal + padding

In [None]:
def assert_raises_same_error(fn1, fn2, *args, **kwargs):
    err1 = err2 = None

    try:
        fn1(*args, **kwargs)
    except Exception as e:
        err1 = e

    try:
        fn2(*args, **kwargs)
    except Exception as e:
        err2 = e

    assert err1 is not None, "Custom function did not raise"
    assert err2 is not None, "PyTorch function did not raise"

    assert type(err1) is type(err2), \
        f"Error types differ: {type(err1)} vs {type(err2)}"

    assert str(err1) in str(err2)

    print("✅ Both implementations raise the same error")

assert_raises_same_error(
    scaled_dot_product_attention_custom,
    F.scaled_dot_product_attention,
    q, k, v,
    attn_mask=attn_mask,
    dropout_p=0.0,
    is_causal=True
)

✅ Both implementations raise the same error
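
Because `is_causal=True` and `attn_mask` cannot be passed together, one way to get causal attention that also ignores padding is to merge both patterns into a single boolean mask and pass it with `is_causal=False`. This is a sketch under the PyTorch convention (`True` = attend); the names `keep`, `causal` and `combined` are illustrative.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, L, D = 2, 6, 8
q, k, v = (torch.randn(B, L, D) for _ in range(3))

keep = torch.tensor([
    [1, 1, 1, 0, 0, 0],
    [1, 1, 1, 1, 0, 0],
], dtype=torch.bool)                                     # True = real token

# PyTorch convention: True = allowed to attend.
causal = torch.tril(torch.ones(L, L, dtype=torch.bool))  # [L, L], lower triangle
combined = causal[None, :, :] & keep[:, None, :]         # [B, L, L]

out_torch = F.scaled_dot_product_attention(
    q, k, v, attn_mask=combined, is_causal=False
)

# Manual reference that applies the same combined mask.
scores = q @ k.transpose(-2, -1) / math.sqrt(D)
scores = scores.masked_fill(~combined, float("-inf"))
out_manual = F.softmax(scores, dim=-1) @ v

torch.testing.assert_close(out_manual, out_torch, rtol=1e-5, atol=1e-5)
print("Combined causal + padding mask matches.")
```

In this example every query row keeps at least key 0, so no row is fully masked; a fully masked row would make softmax operate on all `-inf` values and needs separate handling.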


## Dropout in attention

`dropout_p` is a regularization tool that helps stabilize training and prevent overfitting, especially for large models and long sequences.

- In traditional neural networks, dropout is used for regularization: a random fraction of the activations is "turned off" during training to **prevent the model from overfitting**.
- Attention (especially in transformers) has an analogous mechanism called **attention dropout**: **after** **softmax** produces the attention weights, each weight is set to zero with probability `dropout_p` (and the surviving weights are rescaled by `1 / (1 - dropout_p)`). This makes the attention distribution "noisy" during training, which helps:
  - to keep the model from relying too heavily on the same keys/values,
  - to push the model to spread attention over more positions, which tends to improve generalization.
- In PyTorch's attention function, the `dropout_p` parameter controls this operation: if `dropout_p > 0.0`, dropout is applied to the attention weights.

From the PyTorch [scaled_dot_product_attention documentation](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html):

> Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.

That is, if `dropout_p > 0.0`, dropout is applied to the softmax weights. If `dropout_p = 0.0`, no dropout is applied.

The PyTorch documentation also warns that the function always applies dropout according to `dropout_p`, regardless of training mode, so pass `dropout_p = 0.0` for **inference/eval** to **avoid accidental dropout at inference**.
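
What dropout does to a matrix of attention weights can be inspected directly. This is a small sketch with arbitrary toy weights; `p = 0.5` is chosen only to make the effect easy to see.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
p = 0.5  # deliberately large so the effect is easy to see

weights = F.softmax(torch.randn(1, 4, 4), dim=-1)  # rows sum to 1
dropped = torch.dropout(weights, p, train=True)

# Every entry is either zeroed or scaled by 1 / (1 - p).
survivors = dropped != 0
print(torch.allclose(dropped[survivors], weights[survivors] / (1 - p)))  # True
print(dropped.sum(dim=-1))  # rows no longer sum to exactly 1

# With train=False the weights pass through unchanged.
print(torch.equal(torch.dropout(weights, p, train=False), weights))      # True
```

The rescaling keeps the expected value of each weight unchanged, which is why no correction is needed when dropout is switched off at inference.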


In [41]:
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention_custom(
  q, k, v,
  attn_mask=None,
  is_causal=False,
  dropout_p=0.0,
):

  # API contract: PyTorch does not allow is_causal=True together with attn_mask
  if is_causal and attn_mask is not None:
    raise RuntimeError(
      "Explicit attn_mask should not be set when is_causal=True"
    )

  d_k = k.size(-1)

  # Step 1: Calculate Similarity
  scores = torch.matmul(q, k.transpose(-2, -1))

  # Step 2: Scaling
  scores = scores / math.sqrt(d_k)  # d_k is the last dimension of the keys

  # Step 3a: causal mask (if requested)
  if is_causal:
    L = q.size(-2)
    causal_mask = torch.triu(
      torch.ones(L, L, device=q.device, dtype=torch.bool),
      diagonal=1 # True above the diagonal; the diagonal itself is NOT touched
    )
    # If causal_mask[t, j] == True: scores[t, j] = -inf -> softmax weight = 0
    scores = scores.masked_fill(causal_mask, float('-inf'))

  # Step 3b: attention mask (padding / custom), PyTorch semantics
  if attn_mask is not None:
    if attn_mask.dtype == torch.bool:
      # Boolean mask: True = attend, False = block
      scores = scores.masked_fill(~attn_mask, float('-inf'))
    else:
      # Float mask: added to the scores (0 = keep, -inf = block)
      scores = scores + attn_mask

  # Step 4: Softmax -> attention weights
  attn_weights = F.softmax(scores, dim=-1)

  # Step 5: Dropout
  if dropout_p > 0.0:
    attn_weights = torch.dropout(attn_weights, dropout_p, train=True)

  # Step 6: Weighted sum of values
  output = torch.matmul(attn_weights, v)

  return output


Test dropout

The outputs will not match element-wise: each implementation draws its own random dropout mask.

In [None]:
torch.manual_seed(0)
B, L, D = 2, 6, 8
q = torch.randn(B, L, D)
k = torch.randn(B, L, D)
v = torch.randn(B, L, D)

padding_mask = torch.tensor([
    [1,1,1,0,0,0],
    [1,1,1,1,0,0]
], dtype=torch.bool)
attn_mask = padding_mask[:, None, :]

out_custom = scaled_dot_product_attention_custom(
    q, k, v,
    attn_mask=attn_mask,
    dropout_p=0.1,
    is_causal=False
)

out_torch = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=attn_mask,
    dropout_p=0.1,
    is_causal=False
)
# Values differ because each call draws its own random dropout mask
print("Custom output with dropout:\n", out_custom)
print("PyTorch output with dropout:\n", out_torch)

Custom output with dropout:
 tensor([[[ 9.5632e-01, -1.4926e-01,  9.2037e-01,  8.6204e-01, -9.0496e-01,
           5.9019e-01, -2.8997e-01, -1.1777e+00],
         [ 5.7533e-01,  1.8178e-01,  6.7345e-01,  5.9659e-01, -7.1458e-01,
           7.9585e-01, -3.6383e-01, -5.3604e-01],
         [-1.8156e-01,  3.1829e-02,  1.7305e-01,  1.9385e-01,  1.3330e-01,
           1.2359e-01, -1.4741e-02,  2.4852e-01],
         [ 7.0718e-01, -3.4557e-01,  8.1082e-01,  8.1080e-01, -5.4638e-01,
           2.1007e-01, -9.5922e-02, -1.0049e+00],
         [ 3.8729e-01, -3.5764e-01,  6.9743e-01,  7.3303e-01, -2.2811e-01,
           5.9353e-02,  1.1549e-03, -6.3583e-01],
         [ 8.1055e-01, -2.5410e-01,  8.5741e-01,  8.3156e-01, -7.0109e-01,
           3.8186e-01, -1.8278e-01, -1.0705e+00]],

        [[-5.7556e-01, -5.4034e-02,  2.8969e-01, -7.2152e-01,  3.7089e-01,
          -8.4577e-01,  1.1345e+00,  8.5560e-01],
         [-5.3907e-02,  7.4537e-01,  3.3100e-01, -3.2346e-01, -1.9480e-01,
          -8.2233e-

## Multi-Head

- Runs multiple attention mechanisms in parallel
- Allows the model to jointly attend to information from different representation subspaces at different positions. Each head can potentially learn to focus on different types of relationships or features.

**Multi-Head Workflow**
1. Determine the number of heads: `num_heads`
2. Divide the embedding `d_model -> (num_heads, head_dim)`
3. Parallelize SDPA across all heads simultaneously
4. Broadcast masks to `[B, num_heads, L, L]`
5. Softmax + dropout for each head
6. Concatenate and linearly transform back to `d_model`

What does PyTorch `nn.MultiheadAttention` return?

https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html


```python
# PyTorch returns:
# output: [batch_size, seq_len, embed_dim] (with batch_first=True)
# attn_weights: [batch_size, seq_len, seq_len] (averaged across heads by default)
# or if average_attn_weights=False: [batch_size, num_heads, seq_len, seq_len]
```
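
These shapes can be confirmed by calling the built-in module. The sketch below uses `batch_first=True` so that inputs are `[batch, seq, embed]`; the sizes are the same illustrative values used elsewhere in this notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, embed_dim, num_heads = 2, 5, 512, 8

# batch_first=True so inputs are [batch, seq, embed]; the default is [seq, batch, embed].
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
x = torch.randn(batch_size, seq_len, embed_dim)

out, w_avg = mha(x, x, x, need_weights=True)  # weights averaged over heads
_, w_heads = mha(x, x, x, need_weights=True, average_attn_weights=False)

print(out.shape)      # torch.Size([2, 5, 512])
print(w_avg.shape)    # torch.Size([2, 5, 5])
print(w_heads.shape)  # torch.Size([2, 8, 5, 5])
```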

In [36]:
# Initial parameters:
batch_size = 2 # B
seq_len = 5 # L
embed_dim = 512 # d_model
num_heads = 8 # H

# After head splitting:
head_dim = embed_dim // num_heads # 512 / 8 = 64

# Q, K, V before splitting: [B, L, d_model]
# After splitting: [B, L, H, head_dim]
# After transpose: [B, H, L, head_dim]

Step 1: Basic Structure of MultiHeadAttention

In [42]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
  def __init__(self, embed_dim, num_heads=1, dropout=0.0):
    """
    Args:
      embed_dim: total embedding dimension (d_model)
      num_heads: number of attention heads
      dropout: dropout probability
    """
    super().__init__()
    self.embed_dim = embed_dim
    self.num_heads = num_heads

    # Check: embed_dim must be divisible by num_heads
    assert embed_dim % num_heads == 0, \
      f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"

    # Size per head
    self.head_dim = embed_dim // num_heads

    # Linear layers for Q, K, V
    self.q_proj = nn.Linear(embed_dim, embed_dim)
    self.k_proj = nn.Linear(embed_dim, embed_dim)
    self.v_proj = nn.Linear(embed_dim, embed_dim)

    # Final Linear Layer
    self.out_proj = nn.Linear(embed_dim, embed_dim)

    self.dropout = dropout

  def forward(self, query, key, value, attn_mask=None, is_causal=False,
              need_weights=False, average_attn_weights=True):
    """
    Args:
      query: [batch_size, seq_len, embed_dim]
      key: [batch_size, seq_len, embed_dim]
      value: [batch_size, seq_len, embed_dim]
      attn_mask: optional mask [batch_size, seq_len, seq_len] or [batch_size, num_heads, seq_len, seq_len]
      is_causal: if True, a causal mask is applied
      need_weights: if True, the attention weights are returned as well.
      average_attn_weights: If True and need_weights=True, averages weights across heads.

    Returns:
      output: [batch_size, seq_len, embed_dim]
      attn_weights: if need_weights=True:
        - if average_attn_weights=True: [batch_size, seq_len, seq_len]
        - if average_attn_weights=False: [batch_size, num_heads, seq_len, seq_len]
    """
    batch_size = query.size(0)

    # 1. Linear transformations
    Q = self.q_proj(query)  # [B, L, d_model]
    K = self.k_proj(key)    # [B, L, d_model]
    V = self.v_proj(value)  # [B, L, d_model]

    # 2. Break it down into heads
    Q = self._split_heads(Q)  # [B, H, L, head_dim]
    K = self._split_heads(K)  # [B, H, L, head_dim]
    V = self._split_heads(V)  # [B, H, L, head_dim]

    # 3. Calculate the scaled dot-product attention for all heads
    attn_output, attn_weights = scaled_dot_product_attention(
      Q, K, V,
      attn_mask,
      is_causal,
      dropout_p=self.dropout,
      training=self.training
    )

    # 4. Put the heads back together
    attn_output = self._combine_heads(attn_output)  # [B, L, d_model]

    # 5. Final linear transformation
    output = self.out_proj(attn_output)

    # 6. Handling weights like in PyTorch
    if need_weights:
      if average_attn_weights:
        # Averaging over heads: [B, H, L, L] -> [B, L, L]
        attn_weights = attn_weights.mean(dim=1)
      return output, attn_weights
    else:
      return output, None  # Or just output if you want an exact match with PyTorch

  def _split_heads(self, x):
    """
    Splits the tensor into heads
    [B, L, d_model] -> [B, H, L, head_dim]
    """
    batch_size, seq_len, _ = x.size()

    # Reshape: [B, L, H, head_dim]
    x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)

    # Transpose: [B, H, L, head_dim] (so matmul runs per head)
    x = x.transpose(1, 2)

    return x

  def _combine_heads(self, x):
    """
    Puts the heads back together
    [B, H, L, head_dim] -> [B, L, d_model]
    """
    batch_size, num_heads, seq_len, head_dim = x.size()

    # Transpose: [B, L, H, head_dim]
    x = x.transpose(1, 2)

    # Reshape: [B, L, d_model]
    x = x.contiguous().view(batch_size, seq_len, self.embed_dim)

    return x

Step 2: Modify scaled_dot_product_attention for Multi-Head

Single-Head Attention:
```python
# Q, K, V: [batch, seq_len, d_k]
d_k = k.size(-1) # for example, 512
# d_k is the full embedding dimension.
scores = scores / math.sqrt(d_k)  # divide by sqrt(512)
```

Multi-Head Attention:
```python
# Q, K, V: [batch, num_heads, seq_len, head_dim]
d_k = k.size(-1) # for example, 64 (512 / 8 heads)
# d_k is the per-head dimension (head_dim), not the full embedding dimension.
scores = scores / math.sqrt(d_k)  # divide by sqrt(64)
```

In the original paper "Attention Is All You Need" the formula is:

`Attention(Q, K, V) = softmax(Q·Kᵀ/√d_k)·V`

But in Multi-Head:
- Each head handles a portion of the embedding.
- `d_k` in the formula refers to the KEY dimension in **ONE HEAD**.
- Not to the full embedding dimension!

Masking also changes: the mask is applied to every head.

Single-Head Attention:
```python
# mask = [seq_len, seq_len] or [batch_size, seq_len, seq_len]

if is_causal:
  causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
  scores = scores.masked_fill(causal_mask, float('-inf'))
```

Multi-Head Attention:
```python
# mask = [batch_size, num_heads, seq_len, seq_len]

if is_causal:
  causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
  # 1. Add batch and head dimensions: [1, 1, seq_len, seq_len]
  causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
  # 2. Expand to batch_size and num_heads: [batch_size, num_heads, seq_len, seq_len]
  causal_mask = causal_mask.expand(batch_size, num_heads, -1, -1)
            
  scores = scores.masked_fill(causal_mask, float('-inf'))
```
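
The explicit `expand` above makes the shapes easy to follow, but it is optional: `masked_fill` broadcasts the mask, so an `[L, L]` mask already lines up with the last two dimensions of `[B, H, L, L]` scores. A minimal check with arbitrary toy sizes:

```python
import torch

torch.manual_seed(0)
B, H, L = 2, 4, 5
scores = torch.randn(B, H, L, L)

causal_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)  # [L, L]

# Explicit expand to [B, H, L, L], as in the snippet above.
expanded = causal_mask.unsqueeze(0).unsqueeze(0).expand(B, H, -1, -1)
masked_expanded = scores.masked_fill(expanded, float("-inf"))

# Relying on broadcasting: the [L, L] mask is applied to every batch and head.
masked_broadcast = scores.masked_fill(causal_mask, float("-inf"))

print(torch.equal(masked_expanded, masked_broadcast))  # True
```

`expand` itself does not copy data, so both variants cost about the same; the difference is mainly readability.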

As for the **attention mask**:

For Single-Head:
- dim=2: `[L, L]` -> expand to `[B, L, L]`
- dim=3: `[B, L, L]` already has the right shape
```python
if attn_mask.dim() == 2:          # [L, L]
  attn_mask = attn_mask.unsqueeze(0)  # [1, L, L]
  attn_mask = attn_mask.expand(batch_size, -1, -1)
elif attn_mask.dim() == 3:        # [B, L, L]
  pass
```

Multi-Head:
- dim=2: `[L, L]` -> expand to `[B, H, L, L]`
- dim=3: `[B, L, L]` -> expand to `[B, H, L, L]`
- dim=4: `[B, H, L, L]` already has the right shape
```python
if attn_mask.dim() == 3:  # [B, L, L]
    attn_mask = attn_mask.unsqueeze(1)  # [B, 1, L, L]
    attn_mask = attn_mask.expand(-1, num_heads, -1, -1)
elif attn_mask.dim() == 2:  # [L, L]
    attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, L, L]
    attn_mask = attn_mask.expand(batch_size, num_heads, -1, -1)
elif attn_mask.dim() == 4:  # [B, H, L, L]
    pass
```
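
One common source of such a mask is padding. The sketch below is an illustration, not part of the implementation above: the `lengths` tensor and the sizes are hypothetical. It builds a 3-D boolean mask `[B, L, L]` (True = allowed) from sequence lengths, which then fits the `dim() == 3` branch.

```python
import torch

# Hypothetical batch: two sequences padded to length 4, real lengths 4 and 2
lengths = torch.tensor([4, 2])
seq_len = 4

# key_is_real[b, j] is True when position j of sequence b is a real token
key_is_real = torch.arange(seq_len).unsqueeze(0) < lengths.unsqueeze(1)  # [B, L]

# Every query may attend only to real keys: [B, 1, L] -> [B, L, L]
attn_mask = key_is_real.unsqueeze(1).expand(-1, seq_len, -1)

print(attn_mask.shape)            # torch.Size([2, 4, 4])
print(attn_mask[1, 0].tolist())   # [True, True, False, False]
```

Because only key positions are masked, every row keeps at least one allowed entry, so the softmax never sees a row made entirely of `-inf`.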

In [44]:
def scaled_dot_product_attention(
  q, k, v,
  attn_mask=None,
  is_causal=False,
  dropout_p=0.0,
  training=True
  ):
  """
  Scaled dot-product attention for multi-head inputs
  Args:
    Q: [B, H, L, head_dim]
    K: [B, H, L, head_dim]
    V: [B, H, L, head_dim]
  Returns:
    output: [B, H, L, head_dim]
    attn_weights: [B, H, L, L]
  """
  # API contract: PyTorch does not allow is_causal=True together with an explicit attn_mask
  if is_causal and attn_mask is not None:
    raise RuntimeError(
      "Explicit attn_mask should not be set when is_causal=True"
    )

  # Determine the mode: single-head or multi-head
  is_multi_head = (q.dim() == 4)

  if is_multi_head:
    batch_size, num_heads, seq_len, head_dim = q.size()
  else:  # single-head
    batch_size, seq_len, d_k = q.size()
    num_heads = 1  # for convenience

  # 1. Calculate similarity (Q * K^T)
  scores = torch.matmul(q, k.transpose(-2, -1))

  # 2. Scaling
  scale_dim = q.size(-1) # d_k for single-head, head_dim for multi-head
  scores = scores / math.sqrt(scale_dim)

  # 3. Causal mask
  if is_causal:
    # Create a causal mask for each head
    causal_mask = torch.triu(
      torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool),
      diagonal=1
    )

    # Adapt the dimension to the input
    if is_multi_head:
      # Expand for batch and heads: [1, 1, L, L] -> [B, H, L, L]
      causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
      causal_mask = causal_mask.expand(batch_size, num_heads, -1, -1)
    else:
      # [1, L, L] -> [B, L, L]
      causal_mask = causal_mask.unsqueeze(0)
      causal_mask = causal_mask.expand(batch_size, -1, -1)

    scores = scores.masked_fill(causal_mask, float('-inf'))

  # 4. Attention mask
  if attn_mask is not None:
    # Check and adapt the mask size
    if is_multi_head:
      if attn_mask.dim() == 3:  # [B, L, L]
        attn_mask = attn_mask.unsqueeze(1)  # [B, 1, L, L]
        attn_mask = attn_mask.expand(-1, num_heads, -1, -1)
      elif attn_mask.dim() == 2:  # [L, L]
        attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, L, L]
        attn_mask = attn_mask.expand(batch_size, num_heads, -1, -1)
      elif attn_mask.dim() == 4:  # [B, H, L, L]
        pass
      else:
        raise ValueError(f"Invalid attn_mask dim for multi-head: {attn_mask.dim()}")

    else: # single-head
      if attn_mask.dim() == 2:          # [L, L]
        attn_mask = attn_mask.unsqueeze(0)  # [1, L, L]
        attn_mask = attn_mask.expand(batch_size, -1, -1)
      elif attn_mask.dim() == 3:        # [B, L, L]
        pass
      else:
        raise ValueError(f"Invalid attn_mask dim for single-head: {attn_mask.dim()}")

    # PyTorch semantics: True = allowed, False = masked
    scores = scores.masked_fill(~attn_mask, float('-inf'))

  # 5. Softmax
  attn_weights = F.softmax(scores, dim=-1)

  # 6. Dropout
  if dropout_p > 0.0 and training:
    attn_weights = F.dropout(attn_weights, p=dropout_p)

  # 7. Weighted sum
  output = torch.matmul(attn_weights, v)  # [B, H, L, head_dim]

  return output, attn_weights

Note: this implementation supports only a boolean `attn_mask`, where True = allowed.
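
PyTorch's built-in also accepts a float mask that is added to the scores. As a hedged sketch of how the two conventions relate, with illustrative sizes and variable names, a boolean mask can be turned into an additive one with `0` where attention is allowed and `-inf` where it is masked:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len = 4
scores = torch.randn(1, seq_len, seq_len)

# Boolean mask, True = allowed (lower triangle, so no row is fully masked)
bool_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Equivalent additive mask: 0 where allowed, -inf where masked
float_mask = torch.zeros(seq_len, seq_len).masked_fill(~bool_mask, float("-inf"))

weights_bool = F.softmax(scores.masked_fill(~bool_mask, float("-inf")), dim=-1)
weights_float = F.softmax(scores + float_mask, dim=-1)

print(torch.allclose(weights_bool, weights_float))
```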

Test for multi-head attention

In [45]:
def test_multihead_simple():
  """A simple test for understanding dimensions"""
  print("🧪 Testing MultiHeadAttention dimensions")

  # Parameters
  batch_size = 2
  seq_len = 5
  embed_dim = 512
  num_heads = 8

  # Create the model
  mha = MultiHeadAttention(embed_dim, num_heads)

  # Test data
  query = torch.randn(batch_size, seq_len, embed_dim)
  key = torch.randn(batch_size, seq_len, embed_dim)
  value = torch.randn(batch_size, seq_len, embed_dim)

  # Forward pass
  output, attn_weights = mha(query, key, value, need_weights=True)

  print(f"Input query shape: {query.shape}")
  print(f"Output shape: {output.shape}")
  print(f"Attention weights shape: {attn_weights.shape}")

  # Checking dimensions
  assert output.shape == (batch_size, seq_len, embed_dim), \
    f"Expected output shape {(batch_size, seq_len, embed_dim)}, got {output.shape}"

  assert attn_weights.shape == (batch_size, seq_len, seq_len), \
        f"Expected weights shape {(batch_size, seq_len, seq_len)}, got {attn_weights.shape}"

  print("✅ All dimensions are correct!")

  # attn_weights has shape [B, L, L] here, so [0, 0] selects a single row:
  # the weights of query position 0 in batch 0
  print(f"\n👁️ Example of an attention row (batch 0, query 0):")
  print(attn_weights[0, 0].detach().numpy().round(3))

  return output, attn_weights

test_multihead_simple()

🧪 Testing MultiHeadAttention dimensions
Input query shape: torch.Size([2, 5, 512])
Output shape: torch.Size([2, 5, 512])
Attention weights shape: torch.Size([2, 5, 5])
✅ All dimensions are correct!

👁️ Example of an attention row (batch 0, query 0):
[0.224 0.198 0.195 0.186 0.197]


(tensor([[[ 0.1155, -0.1097,  0.1715,  ..., -0.0114, -0.0032, -0.0264],
          [ 0.1867, -0.1763,  0.2732,  ..., -0.0570,  0.1043, -0.1021],
          [ 0.1622, -0.1501,  0.2313,  ..., -0.0567,  0.0152, -0.0091],
          [ 0.0699, -0.1284,  0.2175,  ...,  0.0474,  0.0816, -0.0261],
          [ 0.1811, -0.0816,  0.2691,  ..., -0.0685,  0.0048,  0.0786]],
 
         [[ 0.1004, -0.1032,  0.1101,  ...,  0.0930, -0.0176, -0.0471],
          [ 0.0747, -0.0830,  0.0581,  ...,  0.1827,  0.0425,  0.0497],
          [ 0.0932, -0.0439,  0.1023,  ...,  0.1871, -0.0284,  0.0468],
          [ 0.0527, -0.0408,  0.0999,  ...,  0.0185,  0.0605,  0.0455],
          [ 0.0853, -0.0938,  0.0902,  ...,  0.1078,  0.0175,  0.0066]]],
        grad_fn=<ViewBackward0>),
 tensor([[[0.2241, 0.1982, 0.1951, 0.1860, 0.1966],
          [0.1781, 0.2107, 0.2200, 0.1852, 0.2061],
          [0.2072, 0.1958, 0.1532, 0.1889, 0.2548],
          [0.2040, 0.2107, 0.2155, 0.1605, 0.2092],
          [0.1800, 0.2434, 0.1852

Compare with the PyTorch implementation

In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def test_multihead_vs_pytorch():
  """
  Comparing our MultiHeadAttention implementation with PyTorch nn.MultiheadAttention
  """
  print("🧪 Testing MultiHeadAttention against PyTorch")
  print("="*60)

  # Test parameters
  torch.manual_seed(42)
  batch_size = 2
  seq_len = 6
  embed_dim = 64 # must be divisible by num_heads
  num_heads = 4

  # Check divisibility
  assert embed_dim % num_heads == 0, f"embed_dim {embed_dim} must be divisible by num_heads {num_heads}"

  # 1. Create test data
  query = torch.randn(batch_size, seq_len, embed_dim)
  key = torch.randn(batch_size, seq_len, embed_dim)
  value = torch.randn(batch_size, seq_len, embed_dim)

  print(f"Test parameters:")
  print(f" batch_size: {batch_size}")
  print(f" seq_len: {seq_len}")
  print(f" embed_dim: {embed_dim}")
  print(f" num_heads: {num_heads}")
  print(f" head_dim: {embed_dim // num_heads}")

  # 2. PyTorch MultiheadAttention
  # NOTE: by default PyTorch expects [seq_len, batch_size, embed_dim];
  # batch_first=True switches it to [batch_size, seq_len, embed_dim]
  pytorch_mha = nn.MultiheadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    batch_first=True, # important! for [batch, seq, embed]
    dropout=0.0 # for deterministic comparison
  )

  # Switch to eval mode (disable dropout)
  pytorch_mha.eval()

  # 3. Our implementation of MultiHeadAttention
  our_mha = MultiHeadAttention(embed_dim, num_heads, dropout=0.0)

  # Copy weights from PyTorch to our model for accurate comparison
  with torch.no_grad():
    # Copy weights from PyTorch
    our_mha.q_proj.weight.copy_(pytorch_mha.in_proj_weight[:embed_dim, :])
    our_mha.q_proj.bias.copy_(pytorch_mha.in_proj_bias[:embed_dim])
    our_mha.k_proj.weight.copy_(pytorch_mha.in_proj_weight[embed_dim:2*embed_dim, :])
    our_mha.k_proj.bias.copy_(pytorch_mha.in_proj_bias[embed_dim:2*embed_dim])
    our_mha.v_proj.weight.copy_(pytorch_mha.in_proj_weight[2*embed_dim:, :])
    our_mha.v_proj.bias.copy_(pytorch_mha.in_proj_bias[2*embed_dim:])
    our_mha.out_proj.weight.copy_(pytorch_mha.out_proj.weight)
    our_mha.out_proj.bias.copy_(pytorch_mha.out_proj.bias)

    # Test 1
    print("\n1. Test Output comparison:")
    with torch.no_grad():
        pytorch_output = pytorch_mha(query, key, value, need_weights=False)[0]

    our_output = our_mha(query, key, value, need_weights=False)[0]

    output_diff = (our_output - pytorch_output).abs().max().item()
    print(f"  Output diff: {output_diff:.2e}")
    print(f"  Outputs match: {'✅' if output_diff < 1e-6 else '❌'}")

    # Test 2
    print("\n2. Test with averaged weights:")
    with torch.no_grad():
        pytorch_output, pytorch_weights = pytorch_mha(
            query, key, value,
            need_weights=True,
            average_attn_weights=True
        )

    our_output, our_weights = our_mha(
        query, key, value,
        need_weights=True,
        average_attn_weights=True
    )

    avg_weights_diff = (our_weights - pytorch_weights).abs().max().item()
    print(f"  Weights diff: {avg_weights_diff:.2e}")
    print(f"  Weights match: {'✅' if avg_weights_diff < 1e-6 else '❌'}")

    # Test 3
    print("\n3. Per-head weights comparison:")
    with torch.no_grad():
      pytorch_output, pytorch_weights = pytorch_mha(
        query, key, value,
        need_weights=True,
        average_attn_weights=False
      )

    our_output, our_weights = our_mha(
      query, key, value,
      need_weights=True,
      average_attn_weights=False
    )

    head_weights_diff = (our_weights - pytorch_weights).abs().max().item()
    print(f"  Per-head weights diff: {head_weights_diff:.2e}")
    print(f"  Per-head weights match: {'✅' if head_weights_diff < 1e-6 else '❌'}")

    # The verdict covers all three checks (the averaged-weights diff is no longer overwritten)
    all_passed = max(output_diff, avg_weights_diff, head_weights_diff) < 1e-6
    print(f"\n🎉 {'ALL TESTS PASSED!' if all_passed else 'SOME TESTS FAILED'}")
    return all_passed

test_multihead_vs_pytorch()

🧪 Testing MultiHeadAttention against PyTorch
Test parameters:
 batch_size: 2
 seq_len: 6
 embed_dim: 64
 num_heads: 4
 head_dim: 16

1. Test Output comparison:
  Output diff: 8.94e-08
  Outputs match: ✅

2. Test with averaged weights:
  Weights diff: 0.00e+00
  Weights match: ✅

3. Per-head weights comparison:
  Per-head weights diff: 0.00e+00
  Per-head weights match: ✅

🎉 ALL TESTS PASSED!


True
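
The weight-copy step in the test relies on how `nn.MultiheadAttention` packs its input projections: when query, key and value share the same embedding size, `in_proj_weight` has shape `[3 * embed_dim, embed_dim]`, with the Q, K and V blocks stacked in that order. A small sketch of the same slicing written with `chunk` (sizes are the ones used in the test):

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 64, 4
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# Packed projection: rows [0:E] -> Q, [E:2E] -> K, [2E:3E] -> V
print(mha.in_proj_weight.shape)  # torch.Size([192, 64])

w_q, w_k, w_v = mha.in_proj_weight.chunk(3, dim=0)
b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)

# chunk() yields the same blocks as the manual slices used in the test
print(torch.equal(w_k, mha.in_proj_weight[embed_dim:2 * embed_dim, :]))
```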

# Additional

## Types of Attention

### Self-Attention
- Keys, queries, and values all come from the same source sequence
- Allows each position to attend to all positions in the sequence

### Cross-Attention
- The queries come from one sequence (e.g., the decoder in a seq2seq model), while the keys and values come from another (e.g., the encoder).
- Often used in machine translation and generative tasks where one sequence attends to another.
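
A minimal sketch of the shape difference, assuming illustrative sizes and leaving out the learned Q/K/V projections: in cross-attention the weight matrix is no longer square, because the query length and the key length come from different sequences.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, d_k = 2, 8     # illustrative sizes
tgt_len, src_len = 3, 5    # decoder and encoder lengths differ

decoder_states = torch.randn(batch_size, tgt_len, d_k)  # source of queries
encoder_states = torch.randn(batch_size, src_len, d_k)  # source of keys and values

q, k, v = decoder_states, encoder_states, encoder_states
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # [B, tgt_len, src_len]
weights = F.softmax(scores, dim=-1)
output = torch.matmul(weights, v)                                # [B, tgt_len, d_k]

print(weights.shape, output.shape)
```

Self-attention is the special case where `q`, `k` and `v` are all derived from the same sequence, so `tgt_len == src_len`.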

### Multi-Head Attention
- Runs multiple attention mechanisms in parallel
- Allows the model to jointly attend to information from different representation subspaces at different positions. Each head can potentially learn to focus on different types of relationships or features.
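
The "parallel heads" are usually just a reshape of the embedding dimension. A short sketch with illustrative sizes shows the split into heads and the merge back, which is lossless:

```python
import torch

batch_size, seq_len, embed_dim, num_heads = 2, 5, 64, 4  # illustrative sizes
head_dim = embed_dim // num_heads

x = torch.randn(batch_size, seq_len, embed_dim)

# [B, L, E] -> [B, L, H, D] -> [B, H, L, D]
heads = x.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

# ... attention would run independently per head here ...

# [B, H, L, D] -> [B, L, H, D] -> [B, L, E]
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)

print(heads.shape)             # torch.Size([2, 4, 5, 16])
print(torch.equal(merged, x))  # True
```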

### Multi-Query Attention (MQA)
- All query heads share the same key and value matrices; only the query matrices differ
- Significantly reduces memory requirements and inference time
- Can lead to quality degradation compared to MHA

### Grouped-Query Attention (GQA)
- Introduced to balance the efficiency of MQA and the quality in MHA
- In MHA, each query head has its own key-value head (maximum quality but high memory usage). In MQA, all query heads share just one key-value head (maximum efficiency but lower quality). GQA divides query heads into groups, where each group shares one key-value head.
- The number of groups (G) is a hyperparameter: more groups bring GQA closer to MHA, fewer groups bring it closer to MQA
- Used in models like Llama 2-70B, Mistral 7B, and Falcon 40B. Particularly useful in multi-GPU environments with tensor parallelism
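
The relationship between MHA, GQA and MQA can be sketched in a few lines. This is a simplified illustration, not how any particular model implements it: the helper `grouped_query_attention` is a hypothetical name, projections and masking are omitted, and the key/value heads are simply repeated so that each group of query heads sees the same K/V.

```python
import math
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v):
    """q: [B, H_q, L, D]; k, v: [B, H_kv, L, D] with H_q divisible by H_kv."""
    num_q_heads, num_kv_heads = q.size(1), k.size(1)
    assert num_q_heads % num_kv_heads == 0
    group_size = num_q_heads // num_kv_heads
    # Repeat each key/value head so every query head in a group sees the same K/V
    k = k.repeat_interleave(group_size, dim=1)
    v = v.repeat_interleave(group_size, dim=1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    return torch.matmul(F.softmax(scores, dim=-1), v)

torch.manual_seed(0)
q = torch.randn(2, 8, 5, 16)        # 8 query heads
kv_gqa = torch.randn(2, 2, 5, 16)   # 2 key/value heads -> GQA with 2 groups
kv_mqa = torch.randn(2, 1, 5, 16)   # 1 key/value head  -> MQA

print(grouped_query_attention(q, kv_gqa, kv_gqa).shape)
print(grouped_query_attention(q, kv_mqa, kv_mqa).shape)
```

The output shape is the same in every variant. What shrinks is the number of key/value heads that have to be stored, which matters most for the KV cache at inference time.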

### Global vs. Local Attention
- Global Attention attends to all positions in the sequence (standard approach). It helps maintain long-range dependencies that local attention might miss.
- Local Attention attends only to a window of positions around the current position. For a fixed window size w, this reduces computational complexity from O(n²) to O(n·w), which is linear in n.
- Architectures like Longformer and BigBird use hybrid approaches combining both: local attention for most tokens, augmented with some form of global attention (specific tokens attending globally, or sparse global attention patterns) to retain the ability to capture long-range dependencies where needed.
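
The local pattern can be sketched as a boolean mask (True = allowed). The helper name `sliding_window_mask` and the sizes are illustrative assumptions; this is not the Longformer or BigBird implementation.

```python
import torch

def sliding_window_mask(seq_len, window):
    """True = allowed: position i may attend to j when |i - j| <= window."""
    idx = torch.arange(seq_len)
    return (idx.unsqueeze(1) - idx.unsqueeze(0)).abs() <= window

mask = sliding_window_mask(seq_len=8, window=1)
print(mask.sum().item())   # 22 allowed pairs instead of 8 * 8 = 64
```

A dense mask like this still costs O(n²) memory. The savings only materialize with kernels or data layouts that skip the masked entries, so the sketch shows the attention pattern rather than the speedup.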

### Multi-Token Attention
- Addresses limitations of single-token attention where individual weights are determined by similarity of just one query-key pair
- Applies convolution operations over queries, keys, and heads to allow neighboring tokens to influence each other's attention weights

# Resources

* [PyTorch Docs: Dropout](https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html)
* [PyTorch Docs: scaled_dot_product_attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
* [Linformer: Self-Attention with Linear Complexity. (2020 Sinong Wang, Belinda Z. Li, Madian Khabsa, Han Fang, Hao Ma)](https://arxiv.org/abs/2006.04768)
* [Gated Attention for Large Language Models: Non-linearity, Sparsity, and Attention-Sink-Free. (2025 May, Zihan Qiu, Zekun Wang, Bo Zheng, Zeyu Huang, Kaiyue Wen, Songlin Yang, Rui Men, Le Yu, Fei Huang, Suozhi Huang, Dayiheng Liu, Jingren Zhou, Junyang Lin)](https://huggingface.co/papers/2505.06708)
* [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (2022 May, Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré)](https://arxiv.org/abs/2205.14135)
* [Longformer: The Long-Document Transformer (2020, Iz Beltagy, Matthew E. Peters, Arman Cohan)](https://arxiv.org/abs/2004.05150)