# vLLM Page KVCache

Blog: 

git: [dhcode-cpp/easy-infer](https://github.com/dhcode-cpp/easy-infer)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Dict, List, Set, Tuple, Optional, Any

torch.manual_seed(42)

<torch._C.Generator at 0x1197c0e30>

## Config

In [2]:
from dataclasses import dataclass

@dataclass
class vLLMEngineConfig:
    max_batch_size = 4
    max_seq_len = 32
    max_prompt_len: int = 16

    # model & kv cache
    num_layers = 3

    # PageKV Cache Setting
    page_size = 64
    num_pages = 1024

    dim = 16
    num_heads = 2
    head_dim = 8
    vocab_size = 20

config = vLLMEngineConfig()

# Request

In [3]:
EOS_TOKEN = 0
class Request:
    def __init__(self,
                 request_id: int,
                 prompt: List[int],
                 max_len: int = 2048):
        self.request_id = request_id
        self.prompt = prompt
        self.generated_tokens = []
        self.status = "REQUEST_WAITING"  # WAITING, RUNNING, COMPLETED
        self.current_length = len(prompt)
        self.max_length = max_len

    def add_token(self, token: int):
        """添加生成的token到请求中"""
        self.generated_tokens.append(token)
        self.current_length += 1
        if self.is_finished():
            self.status = "REQUEST_COMPLETED"
            print(
                f'finished: ID.{self.request_id}, new_len:{len(self.generated_tokens)}')

    def is_finished(self) -> bool:
        """检查请求是否完成(达到最大长度或生成了EOS)"""
        result = (self.current_length >= self.max_length or (
            self.generated_tokens and self.generated_tokens[-1] == EOS_TOKEN))
        return result

    def get_full_sequence(self) -> List[int]:
        """获取完整的序列(prompt + 生成的tokens)"""
        return self.prompt + self.generated_tokens


## Schedular

In [4]:
from queue import deque

class Schedular:
    """管理所有请求的调度和状态"""

    def __init__(self, max_seq_len: int = 1024):
        self.max_seq_len = max_seq_len
        self.requests = {}  # request_id -> Request
        self.waiting_queue = deque()
        self.running_requests = set()

    def add_request(self, prompt: List[int], max_seq_len: int) -> int:
        """添加新请求，返回请求ID"""
        request_id = len(self.requests)
        request = Request(request_id, prompt, max_seq_len)
        self.requests[request_id] = request
        self.waiting_queue.append(request_id)
        return request_id

    def get_available_request(self) -> int:
        """获取可用的批次空位数量"""
        return len(self.waiting_queue)

    def get_pending_requests(self, max_count: int) -> List[Tuple[int, List[int]]]:
        """获取等待处理的请求"""
        available_slots = self.get_available_request()
        count = min(max_count, available_slots, len(self.waiting_queue))

        requests_to_process = []
        for _ in range(count):
            if not self.waiting_queue:
                break
            request_id = self.waiting_queue.popleft()
            request = self.requests[request_id]
            request.status = "REQUEST_RUNNING"
            self.running_requests.add(request_id)
            requests_to_process.append((request_id, request.prompt))
        return requests_to_process

    def update_request(self, request_id: int, next_token: int):
        """更新请求状态"""
        if request_id in self.requests:
            request = self.requests[request_id]
            request.add_token(next_token)
            if request.is_finished():
                self.running_requests.discard(request_id)

    def has_pending_requests(self) -> bool:
        """检查是否有未完成的请求"""
        return len(self.waiting_queue) > 0 or len(self.running_requests) > 0

    def get_num_pending_requests(self) -> int:
        return len(self.waiting_queue)

    def get_num_running_requests(self) -> int:
        return len(self.running_requests)

    def get_running_request_ids(self) -> List[int]:
        """获取当前正在运行的请求ID"""
        return list(self.running_requests)

## PageKVCache 

1. BlockTable
2. PageKVCacheEngine

## BlockTable

In [5]:
class BlockTable:
    """逻辑块表管理 - 仅负责分页资源管理"""

    def __init__(self, page_size: int, num_pages: int):
        self.page_size = page_size
        self.num_pages = num_pages
        self.free_pages = list(range(num_pages))
        self.allocated_pages = set()

        self.page_usage = [0] * num_pages  
        self.next_page = [-1] * num_pages  

    def _allocate_pages(self, num_pages: int, parent_block_id=-1):
        """分配指定数量的页"""
        if len(self.free_pages) < num_pages:
            return None

        allocated = self.free_pages[:num_pages]
        self.free_pages = self.free_pages[num_pages:]
        self.allocated_pages.update(allocated)

        # 初始化块状态
        for page_id in allocated:
            self.page_usage[page_id] = 0
            self.next_page[page_id] = -1

        if parent_block_id != -1:
            self.next_page[parent_block_id] = allocated[0]

        return allocated

    def _free_pages(self, page_ids: list[int]):
        """释放页"""
        for page_id in page_ids:
            if page_id in self.allocated_pages:
                self.allocated_pages.remove(page_id)
                self.free_pages.append(page_id)
                self.page_usage[page_id] = 0
                self.next_page[page_id] = -1

    def get_free_count(self) -> int:
        """获取空闲块数量"""
        return len(self.free_pages)

## PageKVCacheEngine

In [6]:
class PageKVCacheEngine:
    """分页式管理 KV 缓存"""

    def __init__(self, config):
        # 初始化KV缓存 [layer, batch, seq, head, dim]

        self.num_pages = config.num_pages
        self.page_size = config.page_size

        self.block_table = BlockTable(self.page_size,
                                      self.num_pages,)

        # KV 的存储表由块表大小来管理
        self.k_cache = torch.zeros(config.num_layers,
                                   self.num_pages,
                                   self.page_size,
                                   config.num_heads,
                                   config.head_dim)
        self.v_cache = torch.zeros_like(self.k_cache)

        # 每个请求的长度
        self.sequence_lengths = {}

        # 请求与 block_id 的映射信息
        self.request_to_pages = {}  # request_id -> [page_id1, page_id2, ...]
        self.page_to_request = {}  # page_id -> request_id

    def has_active_requests(self) -> bool:
        """检查是否有活跃的请求"""
        return len(self.request_to_pages) > 0

    def has_available_pages(self, request_length) -> bool:
        """检查是否有可用的分页"""
        free_num_pages = self.block_table.get_free_count()
        return request_length < free_num_pages * self.page_size

    def allocate_request_pages(self, request_id, request_length) -> List[int]:
        """
        为 prefill 请求预分配页面, 对于正在解码的 decoding 请求，不需要预分配 page, 后续分配功能写在一起
        """

        allocate_pages_size = (request_length // self.page_size)+1
        allocate_pages_ids = self.block_table._allocate_pages(
            allocate_pages_size)
        if allocate_pages_ids == None:
            print(f'[ALLOCATE] request ID{request_id} pages faild')
            return []

        self.request_to_pages[request_id] = allocate_pages_ids

        for i in self.request_to_pages[request_id]:
            self.page_to_request[i] = request_id
        self.sequence_lengths[request_id] = 0

        print(
            f'[ALLOCATE] request ID{request_id} pages len {len(allocate_pages_ids)}')

        return allocate_pages_ids

    def free_request_pages(self, request_id: int):
        """释放请求占用的页面"""
        allocate_pages_ids = self.request_to_pages[request_id]
        self.block_table._free_pages([0, 1])

        self.k_cache[:, allocate_pages_ids, :, :, :] = 0
        self.v_cache[:, allocate_pages_ids, :, :, :] = 0

        del self.request_to_pages[request_id]
        for idx in allocate_pages_ids:
            del self.page_to_request[idx]
        print(
            f"[FREE] request ID{request_id} pages, len{len(allocate_pages_ids)}")

    def update_pages(self, request_id, new_kv_cache):
        # 1. Prefill: 填充到 requst_id -> pages 上
        # 2. Decoding: 找到 requst_id -> pages 上的最后一个 token, 如果最后一个块已满，需要重新申请一个新的块表。

        # 1. Update Decoding blocks
        # 对于 Decoding 存在增加一个新 token 导致要新加一个 block 的情况
        seq_len, _, _ = new_kv_cache[0][0].shape  # 0 层数据 K数据
        T = self.page_size

        if seq_len == 1:
            length = self.sequence_lengths[request_id]
            pages_ids = self.request_to_pages[request_id]
            if length % T == 0:
                new_block_id = self.block_table._allocate_pages(
                    1, pages_ids[-1])
                self.request_to_pages[request_id].append(new_block_id)
                self.page_to_request[new_block_id] = request_id
            self.sequence_lengths[request_id] += 1
            cur_offset_len = self.sequence_lengths[request_id] % T
        else:
            # prefill 填充 context 长度
            self.sequence_lengths[request_id] = seq_len

        # 2. Update Decoding Stage KV-Cache
        if seq_len == 1:
            for i, layer_kv_cache in enumerate(new_kv_cache):
                # seq_len, num_heads, head_dim = layer_kv_cache[0].shape
                pages_ids = self.request_to_pages[request_id]
                cur_offset_len = self.sequence_lengths[request_id] % T

                # 最后一个块上加 Cache
                self.k_cache[i, pages_ids[-1], cur_offset_len,
                             :, :] = layer_kv_cache[0][0, :, :]
                self.v_cache[i, pages_ids[-1], cur_offset_len,
                             :, :] = layer_kv_cache[1][0, :, :]

        # 3. Update Prefill Stage KV-Cache
        else:
            for i, layer_kv_cache in enumerate(new_kv_cache):
                pages_ids = self.request_to_pages[request_id]
                for k, idx in enumerate(pages_ids):
                    if k != len(pages_ids)-1:
                        self.k_cache[i, idx, :, :,
                                     :] = layer_kv_cache[0][k*T: (k+1)*T]
                        self.v_cache[i, idx, :, :,
                                     :] = layer_kv_cache[1][k*T: (k+1)*T]
                    else:  # 最后一个 block
                        cur_len = self.sequence_lengths[request_id] % T

                        self.k_cache[i, idx, :cur_len, :,
                                     :] = layer_kv_cache[0][k*T: k*T+cur_len]
                        self.v_cache[i, idx, :cur_len, :,
                                     :] = layer_kv_cache[1][k*T: k*T+cur_len]

    def get_sequence_kvcache(self, request_ids: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
        """获取请求列表的batch KVCache 数据"""
        T = self.page_size

        max_pages = -1
        for idx in request_ids:
            max_pages = max(len(self.request_to_pages[idx]), max_pages)

        num_layers, _, _, H, D = self.k_cache.shape

        K = torch.zeros(num_layers, len(request_ids), max_pages * T, H, D)
        V = torch.zeros(num_layers, len(request_ids), max_pages * T, H, D)

        for t, idx in enumerate(request_ids):
            page_ids = self.request_to_pages[idx]

            cur_length = len(page_ids) * T

            K[:, t, :cur_length, :, :] = self.k_cache[:, page_ids,
                                                      :, :, :].reshape(num_layers, cur_length, H, D)
            V[:, t, :cur_length, :, :] = self.v_cache[:, page_ids,
                                                      :, :, :].reshape(num_layers, cur_length, H, D)

        return (K, V)

    def get_request_info(self, ):

        for request_id, page_ids in self.request_to_pages.items():
            print(f'----Req.ID:{request_id}----')
            print('pages list:', page_ids)
            print('cur_length:', self.sequence_lengths[request_id])

## Model

Standard Attention, next-part will implemented PageAttention Kernel with PyTorch

In [7]:
import math
class DecoderBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_heads
        self.dim = config.dim
        self.head_dim = config.head_dim
        self.WQ = nn.Linear(config.dim, config.dim, bias=False)
        self.WK = nn.Linear(config.dim, config.dim, bias=False)
        self.WV = nn.Linear(config.dim, config.dim, bias=False)
        self.WO = nn.Linear(config.dim, config.dim, bias=False)
        self.act = nn.ReLU()

    def forward(self, X, kvcache=None, current_length=None):
        bsz, seq_len, _ = X.shape
        Q, K, V = self.WQ(X), self.WK(X), self.WV(X)
        Q = Q.reshape(bsz, seq_len, self.num_heads,
                      self.head_dim).transpose(1, 2)
        K = K.reshape(bsz, seq_len, self.num_heads, self.head_dim)
        V = V.reshape(bsz, seq_len, self.num_heads, self.head_dim)

        if kvcache is None:
            K_, V_ = K, V
        else:
            K_cache = kvcache[0]
            V_cache = kvcache[1]

            # # left padding 拼接方式
            # K_ = torch.cat((K_cache, K), dim = 1)
            # V_ = torch.cat((V_cache, V), dim = 1)

            # right padding 填充方式
            cache_col = torch.zeros(bsz, 1, self.num_heads, self.head_dim)
            K_ = torch.cat((K_cache, cache_col.clone()), dim=1)
            V_ = torch.cat((V_cache, cache_col.clone()), dim=1)
            K_[:, current_length, :, :] = K
            V_[:, current_length, :, :] = V

        K_ = K_.transpose(1, 2)
        V_ = V_.transpose(1, 2)

        S = Q@K_.transpose(2, 3)//math.sqrt(self.head_dim)
        P = F.softmax(S, dim=-1)
        Z = P@V_
        Z = Z.transpose(1, 2).reshape(bsz, seq_len, self.dim)
        O = self.WO(Z)

        # activate & shorcut
        O_ = X + self.act(O)

        return O_, (K, V)

In [8]:
class ToyModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embd = nn.Embedding(config.vocab_size, config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size)
        self.decoder = nn.ModuleList(
            [DecoderBlock(config) for i in range(config.num_layers)]
        )

    def forward(self, x, kvcaches=None, current_length=None):
        layer_kvcaches = []
        X = self.embd(x)

        for i, block in enumerate(self.decoder):
            if kvcaches == None:
                X, layer_kvcache = block(X, None, None)
            else:
                X, layer_kvcache = block(X,
                                         kvcache=[kvcaches[0][i],
                                                  kvcaches[1][i]],
                                         current_length=current_length)

            layer_kvcaches.append(layer_kvcache)
        logits = self.lm_head(X)
        return logits, layer_kvcaches

## ModelWrapper

In [10]:

class ModelWrapper:
    """封装模型的前向传播"""

    def __init__(self, model, kv_cache_manager: PageKVCacheEngine):
        self.model = model
        self.cacher = kv_cache_manager

    def prefill_requests(self, request_ids: List[int], prompts: List[List[int]]) -> Tuple[torch.Tensor, Any]:
        """预填充新请求"""
        if len(request_ids) == 0:
            return torch.tensor([])

        # 预分配页面
        for request_id, prompt in zip(request_ids, prompts):
            self.cacher.allocate_request_pages(request_id,  len(prompt))

        # padding
        current_lens = [len(prompt) for prompt in prompts]
        max_len = max(current_lens)
        input_ids = torch.zeros(len(prompts), max_len, dtype=torch.long)
        for i, prompt in enumerate(prompts):
            input_ids[i, :len(prompt)] = torch.tensor(prompt)

        # 执行预填充
        with torch.no_grad():
            logits, layer_kvcaches = self.model(input_ids,
                                                # current_length=current_len
                                                )
        L = [ l-1 for l in current_lens]
        return logits[:, L, :], layer_kvcaches

    def decode_next_tokens(self, next_tokens: torch.Tensor, current_length=None, KVCache=None) -> Tuple[torch.Tensor, Any]:
        """解码下一个token"""

        with torch.no_grad():
            logits, layer_kvcaches = self.model(
                next_tokens,
                kvcaches=KVCache,
                current_length=current_length,
            )

        return logits, layer_kvcaches

    def generate_next_tokens(self, logits: torch.Tensor) -> torch.Tensor:
        """从logits生成下一个token(贪婪采样)"""
        if len(logits) == 0:
            return torch.tensor([])
        return torch.argmax(logits, dim=-1)

## vLLM 类

In [11]:
class vLLMPageCacheEngine:
    """vLLM 主引擎"""

    def __init__(self, model, config):
        self.cacher = PageKVCacheEngine(config)
        self.model_wrapper = ModelWrapper(model, self.cacher)
        self.schedular = Schedular(config.max_seq_len,)

    def add_request(self, prompt: List[int], max_seq_len) -> int:
        """添加新请求"""
        return self.schedular.add_request(prompt, max_seq_len)

    def step(self):
        """
        """
        request_ids = None
        layer_kvcaches = None
        # 阶段1: 处理解码(已有请求)
        if self.schedular.get_num_running_requests() > 0:
            
            request_ids = self.schedular.get_running_request_ids()


            # 准备输入token (上一个 step 生成的token)
            input_tokens = torch.tensor([
                self.schedular.requests[req_id].generated_tokens[-1]
                for req_id in request_ids
            ], dtype=torch.long)
            current_length = torch.tensor([
                self.schedular.requests[req_id].current_length
                for req_id in request_ids
            ], dtype=torch.long)
            input_tokens = input_tokens.unsqueeze(dim=1)

            # Page KVCache -> Batch KVCache
            batch_kvcache = self.cacher.get_sequence_kvcache(request_ids)

            # 解码
            logits, layer_kvcaches = self.model_wrapper.decode_next_tokens(input_tokens,
                                                                           KVCache=batch_kvcache,
                                                                           current_length=current_length)
            next_tokens = self.model_wrapper.generate_next_tokens(logits)

            # update kv cache
            self.update_kvcache(request_ids, layer_kvcaches)

            # 更新状态
            for i, request_id in enumerate(request_ids):
                self.schedular.update_request(
                    request_id, next_tokens[i].item())
                if self.schedular.requests[request_id].is_finished():
                    self.cacher.free_request_pages(request_id)

        # 阶段2: 处理预填充(新请求)
        if self.schedular.get_num_pending_requests() > 0:
            pending_requests = self.schedular.get_pending_requests(
                config.num_pages
            )
            request_ids = [idx for idx, _ in pending_requests]
            prompts = [prompt for _, prompt in pending_requests]

            if pending_requests:
                logits, layer_kvcaches = self.model_wrapper.prefill_requests(
                    request_ids, prompts)
                next_tokens = self.model_wrapper.generate_next_tokens(logits)
                
                for i, (request_id, _) in enumerate(pending_requests):
                    self.schedular.update_request(
                        request_id, next_tokens[0,i].item())

                self.update_kvcache(request_ids, layer_kvcaches)

    def update_kvcache(self, request_ids, layer_kvcaches):
        if request_ids != None:
            for i, idx in enumerate(request_ids):
                tmp_cache = [[layer_kvcache[0][i], layer_kvcache[1][i]]
                             for layer_kvcache in layer_kvcaches]
                self.cacher.update_pages(idx, tmp_cache)

    def has_pending_work(self) -> bool:
        """检查是否还有未完成的工作"""
        return self.schedular.has_pending_requests()

    def get_requests_info(self):
        pending = self.schedular.get_num_pending_requests()
        running = self.schedular.get_num_running_requests()
        total_request = len(self.schedular.requests)
        return pending, running, total_request

## 主函数

In [12]:
import random
random.seed(42)  


def listen_request(config, p=0.01):
    prompt=[]
    prompt_len=0
    num = random.randint(1,100)
    if num/100.0 < p:
        prompt_len = random.randint(config.max_prompt_len//4, config.max_prompt_len)
        prompt=torch.randint(low=1, high=config.vocab_size, size=(1, prompt_len))
        prompt=prompt[0].tolist()
    return prompt, prompt_len

In [13]:
N = 32
count = 0

model = ToyModel(config)
engine = vLLMPageCacheEngine(model, config)

# main 
while 1:
    # 监听进程
    if count != N:
        prompt, prompt_len = listen_request(config, p=0.5)
        if prompt_len != 0:
            count += 1
            if count % (N//10) == 0:
                per = count / (N//10) 
                print('Running...:','*'*int(per),'-'*(10-int(per)))
            generate_len = random.randint(prompt_len, config.max_seq_len)
            engine.add_request(prompt, generate_len)
            pending, running, total = engine.get_requests_info()
            print(f'[Request Info] pending:{pending}/running:{running}/total:{total}')
            
    # 处理进程   
    engine.step()

    if not engine.has_pending_work() and count == N:
        print('process done')
        break

[Request Info] pending:1/running:0/total:1
[ALLOCATE] request ID0 pages len 1
[Request Info] pending:1/running:1/total:2
[ALLOCATE] request ID1 pages len 1
Running...: * ---------
[Request Info] pending:1/running:2/total:3
[ALLOCATE] request ID2 pages len 1
finished: ID.2, new_len:3
[FREE] request ID2 pages, len1
[Request Info] pending:1/running:2/total:4
[ALLOCATE] request ID3 pages len 1
[Request Info] pending:1/running:3/total:5
finished: ID.1, new_len:7
[FREE] request ID1 pages, len1
[ALLOCATE] request ID4 pages len 1
Running...: ** --------
[Request Info] pending:1/running:3/total:6
finished: ID.4, new_len:2
[FREE] request ID4 pages, len1
[ALLOCATE] request ID5 pages len 1
[Request Info] pending:1/running:3/total:7
[ALLOCATE] request ID6 pages len 1
finished: ID.6, new_len:6
[FREE] request ID6 pages, len1
[Request Info] pending:1/running:3/total:8
[ALLOCATE] request ID7 pages len 1
Running...: *** -------
[Request Info] pending:1/running:4/total:9
[ALLOCATE] request ID8 pages len 