# Attention Variants

## What is Attention?

* *Attention* is a mechanism that allows a model (usually a transformer) to focus selectively on the most relevant pieces of information when processing data. This is especially important for sequential data like text.
* Instead of treating every input token equally, the model learns which tokens matter most for predicting the next word or generating meaning.
* This solves a major limitation of older models (like RNNs), which struggled to retain important information over long sequences.
* **Note:** Causal attention means the model is only allowed to attend to past and present tokens, never future tokens. It enforces a strict left‑to‑right information flow.
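
As a minimal sketch of the causal restriction in the note above (using NumPy, with placeholder scores of zero and the hypothetical helper names `causal_mask` and `masked_softmax`), a lower-triangular mask can be applied before normalizing, so that each token gives zero weight to every later token:

```python
import numpy as np

def causal_mask(n):
    """Boolean n x n matrix: entry [i, j] is True when token i may attend to token j."""
    return np.tril(np.ones((n, n), dtype=bool))

def masked_softmax(scores, mask):
    """Softmax over the last axis that gives zero weight wherever mask is False."""
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

# Placeholder scores (assumed): every pair of tokens is equally relevant.
scores = np.zeros((4, 4))
weights = masked_softmax(scores, causal_mask(4))
print(weights)  # row i spreads its weight evenly over tokens 0..i only
```

Everything above the diagonal is zero, which is the left-to-right information flow described in the note.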

### How does Attention Work?
* Key components of the attention mechanism:
    * **Queries (Q)**: What the model is looking for.
    * **Keys (K)**: Labels that help the model decide which information is relevant.
    * **Values (V)**: The actual information that gets combined to form meaning.
* Process:
    * The model compares each query to all keys to compute attention scores.
    * These scores are normalized (using softmax) to form attention weights.
    * The model takes a weighted sum of the values, giving more weight to the most relevant tokens.
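
The three steps of the process above can be sketched in a few lines of NumPy. This is an illustration with made-up toy numbers, not a reference implementation; the function name `attend` is an assumption, and the scores are left unscaled here (the scaled form appears in the Steps section).

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attend(Q, K, V):
    scores = Q @ K.T             # step 1: compare each query to all keys
    weights = softmax(scores)    # step 2: normalize scores into attention weights
    return weights @ V, weights  # step 3: weighted sum of the values

# Toy data (assumed for illustration): one query, three key/value pairs.
Q = np.array([[1.0, 0.0]])
K = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
V = np.array([[10.0], [20.0], [30.0]])

output, weights = attend(Q, K, V)
print(weights)  # the first key matches the query best, so it gets the largest weight
print(output)   # a blend of the values, pulled toward the first value
```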

### Why is Attention Important?
* Attention captures long-range (and variable-length) relationships between tokens.
* Parallelizable — unlike RNNs, attention processes all tokens at once, enabling high training efficiency, especially if the training is split across multiple cores of a CPU or GPU.
* Interpretable — attention weights offer insight into what the model “focuses on,” helping with debugging and analysis.
    * We will be focusing on the interpretability of the models we are using throughout this class.

## Multi-Headed Self Attention
* Self‑attention is a mechanism where each token in a sequence looks at all other tokens (including itself) to determine which ones are important for understanding its meaning.
* Each token produces Query (Q), Key (K), and Value (V) vectors.
* The relevance of one token to another is computed as the dot product of the first token's query with the second token's key (Q·K), for every pair of tokens.
* **PyTorch provides this as a built-in module, `torch.nn.MultiheadAttention`.**
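
To make the self-attention description concrete, here is a small single-head NumPy sketch. The sizes and the random matrices `W_q`, `W_k`, `W_v` are assumptions standing in for learned projection weights; the point is only that Q, K, and V all come from the same sequence and that the weights form one row per token.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_model, d_k = 5, 8, 4  # assumed toy sizes

X = rng.normal(size=(n_tokens, d_model))  # one embedding per token
W_q = rng.normal(size=(d_model, d_k))     # random stand-ins for learned projections
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_k))

# Every token produces its own query, key, and value from the same input X.
Q, K, V = X @ W_q, X @ W_k, X @ W_v

scores = Q @ K.T / np.sqrt(d_k)           # Q·K for every pair of tokens
scores -= scores.max(axis=-1, keepdims=True)
weights = np.exp(scores)
weights /= weights.sum(axis=-1, keepdims=True)

output = weights @ V
print(weights.shape)  # (5, 5): row i holds how much token i attends to each token
print(output.shape)   # (5, 4): one updated vector per token
```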

### Multi-Headed vs. Single-Headed
* A single attention operation produces only one set of attention weights per token, so it tends to capture one type of relationship at a time (e.g., syntactic, positional, semantic).
* Multi-headed attention addresses this by running several attention operations in parallel, each with its own set of Q/K/V projections.
* Each head can learn a different “view” of the sequence, for example:
    * Head 1 might track subject–verb agreement
    * Head 2 might focus on sentence structure
    * Head 3 might capture long‑distance dependencies
    * Head 4 might represent semantic parallels
* After processing, the outputs of all heads are concatenated and projected back into a single representation.

### Steps

* The model creates h different sets of Q, K, V vectors from the input embeddings, where h is the number of heads.
* Each head performs the standard scaled dot‑product attention:


$$\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$


* The outputs of all heads (each of size $$d_k$$, assuming the value vectors have the same dimensionality as the key vectors, as is typical) are concatenated into a single vector of size $$h \cdot d_k$$.
* This combined vector is passed through one more linear layer to mix the information across heads.
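
The steps above can be sketched in NumPy as follows. This is a simplified illustration rather than a production implementation: the weights are random stand-ins for learned parameters, biases and masking are omitted, and the function name `multi_head_attention` and the toy sizes are assumptions.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, h):
    n, d_model = X.shape
    d_k = d_model // h

    def split(M):
        # (n, d_model) -> (h, n, d_k): one slice of the features per head
        return M.reshape(n, h, d_k).transpose(1, 0, 2)

    Q, K, V = split(X @ W_q), split(X @ W_k), split(X @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)        # (h, n, n)
    heads = softmax(scores) @ V                             # (h, n, d_k)
    concat = heads.transpose(1, 0, 2).reshape(n, h * d_k)   # concatenate the heads
    return concat @ W_o                                     # final linear layer mixes heads

rng = np.random.default_rng(0)
n, d_model, h = 6, 16, 4  # assumed toy sizes
X = rng.normal(size=(n, d_model))
W_q, W_k, W_v, W_o = (rng.normal(size=(d_model, d_model)) for _ in range(4))

out = multi_head_attention(X, W_q, W_k, W_v, W_o, h)
print(out.shape)  # (6, 16): same shape as the input
```

Splitting one large projection into `h` slices is equivalent to giving each head its own smaller projection matrix, which is why a single `W_q` of shape `(d_model, d_model)` is enough here.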

### Attention is All You Need
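* The famous paper *Attention is All You Need* introduced the Transformer and proposed *scaled dot-product attention*, which is what each head of multi-headed attention performs.
* GPT-style models use multi-headed attention as their attention mechanism. Recent DeepSeek models (DeepSeek-V2 and later) use a variant called Multi-head Latent Attention, which compresses the stored keys and values.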

## What are the Other Options?

## Flash Attention

FlashAttention is a fast, memory-efficient, exact implementation of the attention computation used in multi-head attention. It does not approximate attention. Instead, it is I/O-aware: it processes the query, key, and value matrices in blocks small enough to fit in fast on-chip GPU memory, so the full $$n \times n$$ attention matrix is never written to the slower main GPU memory. This makes it substantially faster and far less memory-hungry, while the math stays exactly the same as in standard multi-headed attention.

**Note:** FlashAttention is not implemented in these notes because its memory optimizations require a GPU (and, in the case of PyTorch, an NVIDIA GPU).

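**PyTorch provides a built-in function, `torch.nn.functional.scaled_dot_product_attention`, which can dispatch to a FlashAttention kernel on supported GPUs.**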
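
The following CPU-only NumPy sketch is not FlashAttention itself, which is a fused GPU kernel that also tiles over queries. It only illustrates the arithmetic idea that lets the result stay exact: keys and values are processed block by block while a running maximum and running sum keep the softmax correct. The function names and block size are assumptions made for this illustration.

```python
import numpy as np

def naive_attention(Q, K, V):
    scores = Q @ K.T / np.sqrt(Q.shape[-1])           # full n x n matrix
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return (weights / weights.sum(axis=-1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block=4):
    """Process K and V in blocks, never holding the full n x n score matrix."""
    n, d = Q.shape
    running_max = np.full((n, 1), -np.inf)
    running_sum = np.zeros((n, 1))
    acc = np.zeros((n, V.shape[-1]))                  # unnormalized output so far
    for start in range(0, K.shape[0], block):
        K_blk, V_blk = K[start:start + block], V[start:start + block]
        scores = Q @ K_blk.T / np.sqrt(d)             # only n x block scores at a time
        new_max = np.maximum(running_max, scores.max(axis=-1, keepdims=True))
        rescale = np.exp(running_max - new_max)       # corrects the earlier partial sums
        p = np.exp(scores - new_max)
        running_sum = running_sum * rescale + p.sum(axis=-1, keepdims=True)
        acc = acc * rescale + p @ V_blk
        running_max = new_max
    return acc / running_sum

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(10, 8)) for _ in range(3))
print(np.allclose(naive_attention(Q, K, V), tiled_attention(Q, K, V)))
```

Both functions compute the same output up to floating-point rounding, which is the sense in which the tiled computation is exact rather than approximate.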

## Local Attention

* Local attention is an attention mechanism that limits each token (or position) to attend only to a small window of nearby tokens, rather than the entire sequence.
* Instead of calculating attention between every pair of positions (global attention, which costs $$O(n^2)$$ in time and memory), local attention only looks at a fixed neighborhood. This reduces the cost to $$O(n \times w)$$, where $$w$$ is the window size.
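* Note that local attention keeps the softmax unchanged. The following feature-map trick belongs to linear attention (see the Linear Attention section), where the softmax kernel is replaced with a feature map $$\phi$$ such that:
$$\text{softmax}(QK^\top) \approx \phi(Q)\phi(K)^\top$$
* A common feature map for linear attention (based on the exponential linear unit) is $$\phi(x)=\text{elu}(x)+1$$.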
* In language, many dependencies are local. Constraining attention avoids flooding the model with global context when only local context matters.
* Local attention has the same mathematics as multi-head attention, but only over the local window.
* Local attention lowers computation and memory usage, matches the local structure of language, and can reduce noise.
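
A minimal NumPy sketch of windowed attention is shown below, assuming a symmetric window centered on each position (other variants, such as causal or blocked windows, exist). The function name `local_attention` and the toy sizes are assumptions. The loop also counts how many query/key pairs are scored, to show the $$O(n \times w)$$ behavior.

```python
import numpy as np

def local_attention(Q, K, V, window):
    """Each position attends only to a window of `window` positions centered on it."""
    n, d = Q.shape
    half = window // 2
    out = np.zeros((n, V.shape[-1]))
    pairs = 0                                      # number of query/key scores computed
    for i in range(n):
        lo, hi = max(0, i - half), min(n, i + half + 1)
        scores = Q[i] @ K[lo:hi].T / np.sqrt(d)    # scores inside the window only
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()                   # ordinary softmax over the window
        out[i] = weights @ V[lo:hi]
        pairs += hi - lo
    return out, pairs

rng = np.random.default_rng(0)
n, d = 8, 4
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))

out, pairs = local_attention(Q, K, V, window=3)
print(pairs, "pairs scored, versus", n * n, "for global attention")
```

With a window large enough to cover the whole sequence, this reduces to ordinary global attention.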

## Linear Attention
* Linear attention is a family of transformer attention mechanisms designed to eliminate the quadratic time and memory bottleneck of the standard softmax attention. 
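* Instead of computing an $$n \times n$$ attention matrix (which becomes impractical for long sequences), linear attention restructures the math so that the attention computation scales linearly in sequence length.
* It does this by replacing the softmax similarity with a feature map $$\phi$$ applied to the queries and keys (for example $$\phi(x)=\text{elu}(x)+1$$). Because matrix multiplication is associative, $$(\phi(Q)\phi(K)^\top)V$$ can be computed as $$\phi(Q)(\phi(K)^\top V)$$, which never forms the $$n \times n$$ matrix. Unlike local attention, this covers the whole sequence rather than a sliding window.
* The feature map can blur long-range patterns, so linear attention can struggle with NLP tasks, but it has been shown to work well on forecasting tasks.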
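
The restructuring described above can be sketched in NumPy for the non-causal case, using the feature map elu(x) + 1. The function names are assumptions for illustration. The first function builds the full similarity matrix for reference; the second reorders the products so that only small `d`-sized summaries are kept.

```python
import numpy as np

def phi(x):
    """Feature map elu(x) + 1, which is always positive."""
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def linear_attention_quadratic(Q, K, V):
    """Reference version: builds the full n x n similarity matrix."""
    sim = phi(Q) @ phi(K).T                        # (n, n)
    return (sim @ V) / sim.sum(axis=-1, keepdims=True)

def linear_attention(Q, K, V):
    """Same result, reordered so that no n x n matrix is formed."""
    Qf, Kf = phi(Q), phi(K)
    kv = Kf.T @ V                                  # (d, d_v) summary of all keys and values
    z = Kf.sum(axis=0)                             # (d,) normalizer
    return (Qf @ kv) / (Qf @ z)[:, None]

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(12, 4)) for _ in range(3))
print(np.allclose(linear_attention_quadratic(Q, K, V), linear_attention(Q, K, V)))
```

The cost of `linear_attention` grows with the sequence length `n` times the square of the feature size, rather than with `n` squared.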

## Multi-Query Attention
* In standard multi‑head attention each head has its own Q, K, V vectors. When making predictions, the model must cache (store in memory) the past K and V for every head at every layer to avoid recomputing them for each generated token. 
* Multi-Query Attention keeps per‑head queries (Q) but uses a single shared set of keys and values (K, V) across all heads within a layer.
* KV cache size (the amount of memory needed to store the key and value vectors) is reduced by roughly a factor of the number of heads (e.g., 8× or 16×), yielding lower memory use, higher throughput, and faster generation.
* There is a trade-off in model quality, but multi-query attention works well in several contexts.
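
As a rough sketch of the idea (NumPy, random stand-in weights, no output projection or masking), the queries below are split into `h` heads while a single key matrix and a single value matrix are shared by all of them. The helper `kv_cache_floats` and all sizes are assumptions used only to illustrate the cache-size arithmetic.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def multi_query_attention(X, W_q, W_k, W_v, h):
    """h query heads share one key head and one value head."""
    n, d_model = X.shape
    d_k = d_model // h
    Q = (X @ W_q).reshape(n, h, d_k).transpose(1, 0, 2)  # (h, n, d_k): per-head queries
    K = X @ W_k                                          # (n, d_k): shared by all heads
    V = X @ W_v                                          # (n, d_k): shared by all heads
    scores = Q @ K.T / np.sqrt(d_k)                      # (h, n, n) via broadcasting
    heads = softmax(scores) @ V                          # (h, n, d_k)
    return heads.transpose(1, 0, 2).reshape(n, h * d_k), K, V

def kv_cache_floats(n_tokens, n_layers, n_kv_heads, d_k):
    """Numbers stored in the KV cache; the factor 2 covers keys plus values."""
    return 2 * n_tokens * n_layers * n_kv_heads * d_k

rng = np.random.default_rng(0)
n, d_model, h = 6, 16, 4  # assumed toy sizes
d_k = d_model // h
X = rng.normal(size=(n, d_model))
W_q = rng.normal(size=(d_model, d_model))
W_k = rng.normal(size=(d_model, d_k))  # one key projection instead of h
W_v = rng.normal(size=(d_model, d_k))  # one value projection instead of h

out, K, V = multi_query_attention(X, W_q, W_k, W_v, h)
mha_cache = kv_cache_floats(n_tokens=1024, n_layers=12, n_kv_heads=h, d_k=d_k)
mqa_cache = kv_cache_floats(n_tokens=1024, n_layers=12, n_kv_heads=1, d_k=d_k)
print(out.shape, K.shape, mha_cache // mqa_cache)
```

The ratio of the two cache sizes equals the number of heads, which matches the reduction factor described above.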

## Additional Resources

### Multi-Headed Attention (Scaled Dot-Product Attention)
* [PyTorch Documentation](https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html)
* [Attention is All You Need](https://arxiv.org/abs/1706.03762)
* [Multi-Head Attention](https://d2l.ai/chapter_attention-mechanisms-and-transformers/multihead-attention.html)
* [The Multi-head Attention Mechanism Explained!](https://newsletter.theaiedge.io/p/the-multi-head-attention-mechanism)

### Flash Attention
* [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (Original Paper)](https://arxiv.org/abs/2205.14135)
* [Basic idea behind flash attention (V1)](https://damek.github.io/random/basic-idea-behind-flash-attention/)
* [Flash Attention](https://huggingface.co/docs/text-generation-inference/en/conceptual/flash_attention)
* [Official Code Repo](https://github.com/Dao-AILab/flash-attention)

### Local Attention
* [Local Attention Mechanism: Boosting the Transformer Architecture for Long-Sequence Time Series Forecasting](https://arxiv.org/abs/2410.03805)
* [Two minutes NLP — Visualizing Global vs Local Attention](https://medium.com/nlplanet/two-minutes-nlp-visualizing-global-vs-local-attention-c61b42758019)
* [An Implementation on GitHub](https://github.com/lucidrains/local-attention)

### Linear Attention
* [Linear Attention Fundamentals](https://haileyschoelkopf.github.io/blog/2024/linear-attn/)
* [Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention (Original Paper)](https://arxiv.org/pdf/2006.16236)
* [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/pdf/2312.06635)
* [An Implementation on GitHub](https://github.com/lucidrains/linear-attention-transformer)

### Multi-Query Attention
* [Fast Transformer Decoding: One Write-Head is All You Need (Original Paper)](https://arxiv.org/abs/1911.02150)
* [Multi-Query Attention Explained](https://pub.towardsai.net/multi-query-attention-explained-844dfc4935bf)
* [Multi-Query Attention is All You Need](https://fireworks.ai/blog/multi-query-attention-is-all-you-need)
* [A Gentle Introduction to Multi-Head Attention and Grouped-Query Attention](https://machinelearningmastery.com/a-gentle-introduction-to-multi-head-attention-and-grouped-query-attention/)
* [An Implementation on GitHub](https://github.com/kyegomez/MultiQueryAttention)