# Key-value cache and Paged Attention

This notebook was made by [n.luneva](https://github.com/lwtztea)

## Key-Value Cache

### Theory

https://huggingface.co/docs/transformers/main/kv_cache

The Key-Value Cache is an optimization used in transformer models to speed up text generation in **autoregressive tasks**, such as machine translation, text generation or question answering. In these tasks, the model generates a sequence of tokens **one at a time**, where each new token depends on the previously generated tokens.

![](https://miro.medium.com/v2/resize:fit:1400/format:webp/1*uyuyOW1VBqmF5Gtv225XHQ.gif)

#### Main Idea

In a transformer, each layer uses the attention mechanism to compute a representation of the current token from the context of the previous tokens. To do this, the attention mechanism performs the following steps:

1. **Compute Query.** A query is computed for the current token.
2. **Compute Key and Value.** Keys and values are computed for all input tokens (including the current one).
3. **Compute Attention.** Based on the query and keys, weights are calculated, which are then used for the weighted summation of values.
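
The three steps above can be sketched for a single attention head as follows. This is a minimal illustration, not the GPT-2 implementation: the projection matrices `W_q`, `W_k`, `W_v`, the sizes and the random inputs are all assumptions chosen for the example.

```python
import math
import torch

torch.manual_seed(0)
d_model, head_dim, seq_len = 16, 8, 5  # illustrative sizes (assumed)

# Hypothetical projection weights for one attention head
W_q = torch.randn(d_model, head_dim)
W_k = torch.randn(d_model, head_dim)
W_v = torch.randn(d_model, head_dim)

hidden = torch.randn(seq_len, d_model)  # hidden states of all tokens so far
current = hidden[-1:]                   # the current (last) token, shape [1, d_model]

# 1. Query for the current token only
q = current @ W_q                       # [1, head_dim]
# 2. Keys and values for all tokens, including the current one
k = hidden @ W_k                        # [seq_len, head_dim]
v = hidden @ W_v                        # [seq_len, head_dim]
# 3. Attention weights, then the weighted sum of values
weights = torch.softmax(q @ k.T / math.sqrt(head_dim), dim=-1)  # [1, seq_len]
attn_output = weights @ v               # [1, head_dim]

print(weights.shape, attn_output.shape)
```

Only step 2 touches the earlier tokens, which is why it is the part worth caching.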

During autoregressive text generation (token by token), many computations are repeated. For example, if the model has already processed the first 100 tokens, when generating the 101st token, it is unnecessary to recompute keys and values for the first 100 tokens, as they remain unchanged. This is where the Key-Value Cache comes into play.

#### How the Key-Value Cache Works

1. **Caching Keys and Values:**
   - On the first step, the model processes the whole prompt and computes keys and values for every prompt token.
   - These keys and values are stored in the cache.
   - On the next step, only the newly generated token is fed to the model: its keys and values are computed and appended to the cache, while those of all earlier tokens are read from the cache.
   - This process repeats for each new token: keys and values for previously processed tokens come from the cache, and only those for the new token are computed.
2. **Reducing Computational Load:**
   - Without caching, at each generation step, the model would have to recompute keys and values for all previous tokens, leading to significant computational overhead.
   - With caching, the amount of computation is greatly reduced since the keys and values for older tokens are not recomputed.
3. **Memory Usage:**
   - The cache is stored in memory, and its size grows linearly with the length of the input sequence.
   - This requires additional memory but significantly speeds up the generation process.
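
The caching scheme above can be checked on a toy single-head example. The sketch below is an assumption-laden illustration: the weights and the "hidden states" in `tokens` are random stand-ins, not values from a real model. At every step it compares attention computed from the cache with attention recomputed from scratch.

```python
import math
import torch

torch.manual_seed(0)
d_model, head_dim = 16, 8  # illustrative sizes (assumed)
W_q, W_k, W_v = (torch.randn(d_model, head_dim) for _ in range(3))

def attend(q, k, v):
    """Scaled dot-product attention of one query over all keys/values."""
    weights = torch.softmax(q @ k.T / math.sqrt(head_dim), dim=-1)
    return weights @ v

tokens = torch.randn(6, d_model)  # stand-in hidden states, one row per token

k_cache = torch.empty(0, head_dim)
v_cache = torch.empty(0, head_dim)
max_diff = 0.0
for t in range(tokens.shape[0]):
    x = tokens[t : t + 1]
    # With cache: project only the new token and append it to the cache
    k_cache = torch.cat([k_cache, x @ W_k])
    v_cache = torch.cat([v_cache, x @ W_v])
    out_cached = attend(x @ W_q, k_cache, v_cache)

    # Without cache: re-project every token seen so far
    seen = tokens[: t + 1]
    out_full = attend(x @ W_q, seen @ W_k, seen @ W_v)

    max_diff = max(max_diff, (out_cached - out_full).abs().max().item())

print(f"cache length: {k_cache.shape[0]}, max difference: {max_diff:.2e}")
```

The two results agree up to floating-point rounding, while the cached path performs one key and one value projection per step instead of `t + 1`.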

#### Advantages of Key-Value Cache

1. **Faster Generation.** Caching avoids redundant computations of keys and values for already processed tokens, significantly speeding up the generation process.
2. **Computational Resource Savings.** Reduces the number of matrix multiplication operations, which is especially important for large models with many parameters.
3. **Scalability.** Suitable for handling long sequences, where repeated calculations would become prohibitively expensive.
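
To get a feeling for the savings, one can count how many token positions need their keys and values computed per layer. The helper `kv_projections` below is a hypothetical counting function written for this illustration; it ignores the cost of the attention product itself.

```python
def kv_projections(prompt_len: int, new_tokens: int, use_cache: bool) -> int:
    """Count token positions whose keys/values are computed (per layer)."""
    total = 0
    for step in range(new_tokens):
        if use_cache:
            # First step processes the whole prompt, later steps one token each
            total += prompt_len if step == 0 else 1
        else:
            # Every step re-processes the entire sequence so far
            total += prompt_len + step
    return total

prompt_len, new_tokens = 100, 50
without_cache = kv_projections(prompt_len, new_tokens, use_cache=False)
with_cache = kv_projections(prompt_len, new_tokens, use_cache=True)
print(without_cache, with_cache)
```

For a 100-token prompt and 50 generated tokens this gives 6225 positions without the cache versus 149 with it, and the gap widens quadratically as more tokens are generated.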

#### Drawbacks of Key-Value Cache

1. **Memory Usage**:
   - The cache requires additional memory to store keys and values for all processed tokens.
   - The size of the cache grows linearly with the length of the input sequence, which can become an issue for very long sequences.
2. **Limitations for Some Tasks**:
   - If the input sequence changes dynamically (e.g., in streaming processing), managing the cache can become more complex.
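
The linear growth in memory can be estimated directly. The sketch below assumes the configuration of the small GPT-2 model (12 layers, 12 heads, head dimension 64) and float32 storage; `kv_cache_bytes` is a helper written for this example, not a library function.

```python
def kv_cache_bytes(num_layers, num_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_value=4):
    """KV cache size: keys and values (factor 2) for every layer, head and token."""
    return 2 * num_layers * num_heads * head_dim * seq_len * batch_size * bytes_per_value

# GPT-2 small: 12 layers, 12 heads, head_dim 64; float32 values (4 bytes each)
per_token = kv_cache_bytes(12, 12, 64, seq_len=1)
full_context = kv_cache_bytes(12, 12, 64, seq_len=1024)
print(f"{per_token / 1024:.0f} KiB per token, {full_context / 2**20:.0f} MiB at 1024 tokens")
```

That is 72 KiB per token and 72 MiB for a full 1024-token context of a single sequence. Larger models and batches multiply these numbers accordingly.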

### Practice

#### Loading Data and Model

In [1]:
!pip install datasets

In [2]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

In [3]:
# Load model
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Load data
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

In [4]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [5]:
text = dataset["text"][3]
text = text[:100]  # get text prefix
print(text)

 Robert Boulter is an English film , television and theatre actor . He had a guest @-@ starring role


#### Implementation

In [6]:
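class KeyValueCache:
    """Greedy decoding that reuses cached keys and values between steps."""

    def __init__(self, model):
        self.model = model

    def generate_with_cache(self, input_ids, max_length=50):
        # max_length is the number of new tokens to generate
        past_key_values = None  # empty cache before the first forward pass
        generated_tokens = []

        with torch.no_grad():
            for _ in range(max_length):
                # First step: the whole prompt; later steps: only the newest token
                outputs = self.model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
                logits = outputs.logits[:, -1, :]  # logits for the last position
                past_key_values = outputs.past_key_values  # cache now covers the tokens just fed

                next_token = torch.argmax(logits, dim=-1)  # greedy choice
                generated_tokens.append(next_token.item())

                # Feed only the new token next time; earlier ones live in the cache
                input_ids = next_token.unsqueeze(0)

        # Note: relies on the notebook-level `tokenizer` defined above
        return tokenizer.decode(generated_tokens), outputs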

In [7]:
# Initialize kv-cache
cache_model = KeyValueCache(model)

# Generate using kv-cache
input_ids = tokenizer.encode(text, return_tensors="pt")
generated_text, outputs = cache_model.generate_with_cache(input_ids)

print(f"Input text: {text}")
print(f"Generated text: {generated_text}")

Input text:  Robert Boulter is an English film , television and theatre actor . He had a guest @-@ starring role
Generated text:  in the film, and he is a member of the cast of the film. He is also a member of the cast of the film.

He is a member of the cast of the film, and he is a member of the cast of


In [8]:
len(outputs.past_key_values)  # attention blocks

12

In [9]:
len(outputs.past_key_values[0])  # keys and values cache

2

In [10]:
outputs.past_key_values[0][0].shape  # [batch_size, num_heads, seq_len, head_dim]

torch.Size([1, 12, 72, 64])

In [11]:
outputs.past_key_values[0][1][0][0]

tensor([[ 0.0223,  0.0035,  0.0491,  ...,  0.0357,  0.0884,  0.0652],
        [-0.0326, -0.1449,  0.0364,  ...,  0.1332,  0.0875,  0.0343],
        [-0.0741, -0.1271, -0.0747,  ..., -0.2468,  0.0759,  0.0554],
        ...,
        [-0.0129, -0.1281, -0.0717,  ..., -0.0452, -0.0520,  0.2178],
        [-0.1865, -0.0098,  0.1319,  ...,  0.0968,  0.0990, -0.1065],
        [ 0.4145,  0.1236,  0.0328,  ..., -0.1359, -0.1635, -0.0631]])

## Paged Attention

### Theory

https://blog.vllm.ai/2023/06/20/vllm.html


<img src="https://raw.githubusercontent.com/lwtztea/ml_pic/957bf7b/week_6/memory_wastes.png" width="2000">

![](https://blog.vllm.ai/assets/figures/annimation1.gif)

#### Problem — Managing Attention in Transformers

In traditional transformer architectures, the attention mechanism requires storing and processing keys and values for each position in the sequence. For a sequence of length N, each layer keeps key and value matrices of size N×d, where d is the dimensionality of the hidden space. The memory required to store these matrices grows linearly with N, so for very large N (e.g., thousands or millions of tokens) it can become overwhelming.

Additionally, when using mechanisms like KV-caching (caching keys and values for reuse in autoregressive models), **memory quickly fills up**, especially if the model is working with long contexts.
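
One reason memory fills up quickly is that a cache stored contiguously has to reserve room for the longest possible sequence up front, even if most requests end much earlier. The numbers below are made up for illustration, and `reserved_vs_used` is a hypothetical helper.

```python
def reserved_vs_used(actual_lens, max_len):
    """Compare slots reserved up front (max_len per request) with slots actually used."""
    reserved = max_len * len(actual_lens)
    used = sum(actual_lens)
    return reserved, used

# Hypothetical batch: four requests whose final lengths differ a lot
actual_lens = [40, 250, 1024, 90]
reserved, used = reserved_vs_used(actual_lens, max_len=1024)
print(f"reserved={reserved}, used={used}, wasted={1 - used / reserved:.0%}")
```

In this made-up batch, 4096 token slots are reserved but only 1404 are used, so roughly two thirds of the reserved cache stays empty.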

#### Solution — Paged Attention

Paged Attention proposes an approach inspired by **memory management principles in operating systems**, where data is divided into fixed-size blocks (pages). The idea is to split large matrices of keys and values into smaller "pages" of fixed size and manage them more efficiently.

#### Key Ideas Behind Paged Attention

1. **Splitting into Pages:**
   - The keys and values of each sequence are divided into small blocks (pages) that hold a fixed number of tokens.
   - For example, if the sequence length N=1024 and the page size P=64, the cache is split into 1024/64=16 pages.
2. **Memory Management:**
   - Pages do not have to be contiguous in GPU memory: a per-sequence block table maps logical pages to physical ones, and a new page is allocated only when the previous one is full. Waste is limited to the unused slots of the last page.
   - Pages of paused (preempted) sequences can be moved to slower memory (e.g., CPU RAM) and brought back later.
3. **Optimized Access:**
   - When computing attention for a token, the kernel finds the sequence's pages through the block table and reads the keys and values page by page.
   - Because memory is no longer reserved up front for the maximum length, memory usage drops and more sequences fit into GPU RAM at the same time.
4. **Support for Long Sequences:**
   - Thanks to paging, a sequence can keep growing without a large contiguous region being reserved for it in advance.
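
The bookkeeping behind these ideas can be sketched with a free list of physical pages and a per-sequence block table. The classes `BlockAllocator` and `SequenceBlockTable` are hypothetical names invented for this illustration; they track only page ids, not the actual key and value tensors, and are not vLLM's API.

```python
from collections import deque


class BlockAllocator:
    """Toy allocator: hands out physical page ids from a free list."""

    def __init__(self, num_pages):
        self.free = deque(range(num_pages))

    def allocate(self):
        if not self.free:
            raise MemoryError("No free pages left")
        return self.free.popleft()


class SequenceBlockTable:
    """Maps the logical pages of one sequence to physical page ids."""

    def __init__(self, allocator, page_size):
        self.allocator = allocator
        self.page_size = page_size
        self.pages = []      # logical page index -> physical page id
        self.num_tokens = 0

    def append_token(self):
        # A new page is allocated only when the last one is full
        if self.num_tokens % self.page_size == 0:
            self.pages.append(self.allocator.allocate())
        self.num_tokens += 1

    def locate(self, position):
        """Return (physical page id, offset inside the page) for a token position."""
        return self.pages[position // self.page_size], position % self.page_size


allocator = BlockAllocator(num_pages=8)
seq_a = SequenceBlockTable(allocator, page_size=4)
seq_b = SequenceBlockTable(allocator, page_size=4)
for _ in range(6):   # sequence A grows to 6 tokens (2 pages)
    seq_a.append_token()
for _ in range(3):   # sequence B starts in between (1 page)
    seq_b.append_token()
for _ in range(3):   # sequence A grows to 9 tokens (needs a 3rd page)
    seq_a.append_token()

print(seq_a.pages, seq_b.pages, seq_a.locate(8))
```

Sequence A ends up on physical pages 0, 1 and 3, with sequence B's page 2 in between: the pages of one sequence are not contiguous, yet `locate` still finds every token.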

### Practice

A toy implementation of Paged Attention: keys and values are stored per page, and attention is computed over the tokens of a single page.

In [12]:
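class PagedAttention:
    def __init__(self, page_size=128):
        self.page_size = page_size  # maximum number of tokens per page
        self.pages = {}  # page_id -> {"keys": [...], "values": [...]}

    def add_to_page(self, key, value, page_id):
        if page_id not in self.pages:
            self.pages[page_id] = {"keys": [], "values": []}
        if len(self.pages[page_id]["keys"]) >= self.page_size:
            raise ValueError("Page is full!")
        self.pages[page_id]["keys"].append(key)
        self.pages[page_id]["values"].append(value)

    def get_attention(self, query, page_id):
        if page_id not in self.pages:
            raise ValueError("Page not found!")

        # Each key/value has shape [batch, head_dim]; stack along a new token axis
        keys = torch.stack(self.pages[page_id]["keys"], dim=1)  # [batch, tokens, head_dim]
        values = torch.stack(self.pages[page_id]["values"], dim=1)  # [batch, tokens, head_dim]
        # Scaled dot-product scores of the query against every token in the page
        attention_scores = torch.matmul(query.unsqueeze(1), keys.transpose(-2, -1)) / keys.shape[-1] ** 0.5
        attention_weights = torch.softmax(attention_scores, dim=-1)  # normalize over tokens
        output = torch.matmul(attention_weights, values)  # [batch, 1, head_dim]
        return output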

In [13]:
paged_attention = PagedAttention(page_size=128)

# q,k,v examples
query = torch.randn(1, 64)
key = torch.randn(1, 64)
value = torch.randn(1, 64)

paged_attention.add_to_page(key, value, page_id=0)
output = paged_attention.get_attention(query, page_id=0)
print(f"Output attention: {output}")

Output attention: tensor([[[ 0.1459,  0.1059, -2.0250,  1.4762,  0.1633, -1.2413,  0.1044,
          -0.1903,  0.2401, -0.7890, -1.4566,  0.2000, -0.5124,  0.1269,
          -0.8701, -1.1005,  0.0757, -1.0219, -1.2644,  1.2178, -3.2895,
          -1.3653,  2.0364, -1.4599, -0.2574, -0.0435, -0.2313, -0.2819,
          -0.4846,  1.3289, -1.7133,  0.0908, -0.9375, -1.4625, -0.3167,
          -0.3357,  1.1040, -0.7233, -0.1881, -0.4465,  1.2115,  0.0487,
          -0.4799, -0.7031, -1.8934,  1.1073,  0.8395, -0.7706, -2.1159,
           0.1847, -2.5267, -1.0058, -0.6559, -0.3694,  1.4340, -0.0253,
           2.2207, -1.1776, -0.6141,  0.3837,  0.9933,  0.9670,  0.6336,
          -0.6800]]])
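
The toy class attends within one page only. In a paged cache the keys and values of a single sequence span several pages, and the result must match attention over the whole sequence. A minimal sketch of one way to achieve this is to process the pages one at a time and merge the partial results with a running softmax. This is an illustrative technique under assumed sizes and random data, not a description of how vLLM's kernel is written.

```python
import math
import torch

torch.manual_seed(0)
head_dim, page_size, seq_len = 64, 4, 10  # illustrative sizes (assumed)

query = torch.randn(1, head_dim)
keys = torch.randn(seq_len, head_dim)
values = torch.randn(seq_len, head_dim)

# Split the cache into pages; the last page is only partly filled
key_pages = torch.split(keys, page_size)
value_pages = torch.split(values, page_size)

# Process one page at a time, keeping a running max, normalizer and weighted sum
running_max = torch.full((1, 1), float("-inf"))
normalizer = torch.zeros(1, 1)
weighted_sum = torch.zeros(1, head_dim)
for k_page, v_page in zip(key_pages, value_pages):
    scores = query @ k_page.T / math.sqrt(head_dim)   # [1, tokens_in_page]
    new_max = torch.maximum(running_max, scores.max(dim=-1, keepdim=True).values)
    rescale = torch.exp(running_max - new_max)         # re-weights earlier pages
    exp_scores = torch.exp(scores - new_max)
    normalizer = normalizer * rescale + exp_scores.sum(dim=-1, keepdim=True)
    weighted_sum = weighted_sum * rescale + exp_scores @ v_page
    running_max = new_max
paged_output = weighted_sum / normalizer

# Reference: ordinary attention over the whole contiguous cache
weights = torch.softmax(query @ keys.T / math.sqrt(head_dim), dim=-1)
full_output = weights @ values

print(torch.allclose(paged_output, full_output, atol=1e-5))
```

Because each page is visited independently, the pages can live anywhere in memory; only the order-independent running statistics are carried from one page to the next.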
