[Experimental] Prefix Caching Support (vllm-project#1669)
Co-authored-by: DouHappy <2278958187@qq.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
3 people authored and hongxiayang committed Jan 18, 2024
1 parent 11c1efb commit 327ccab
Showing 20 changed files with 1,356 additions and 71 deletions.
4 changes: 4 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -31,6 +31,10 @@ steps:
- pytest -v -s models --forked
soft_fail: true

- label: Prefix Caching Test
commands:
- pytest -v -s prefix_caching

- label: Samplers Test
command: pytest -v -s samplers --forked

51 changes: 51 additions & 0 deletions examples/offline_inference_with_prefix.py
@@ -0,0 +1,51 @@
from vllm import LLM, SamplingParams

prefix = (
"You are an expert school principal, skilled in effectively managing "
"faculty and staff. Draft 10-15 questions for a potential first grade "
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
"community, joyful discovery, and life-long learning. The candidate is "
"coming in for a first-round panel interview for a 8th grade Math "
"teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on these information, fulfill "
"the following paragraph: ")

# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")

generating_prompts = [prefix + prompt for prompt in prompts]

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(generating_prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print("-" * 80)

# -1 since the last token can change when concatenating prompts.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1

# Generate with prefix
outputs = llm.generate(generating_prompts, sampling_params,
prefix_pos=[prefix_pos] * len(generating_prompts))

# Print the outputs. You should see the same outputs as before.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
168 changes: 168 additions & 0 deletions tests/kernels/test_prefix_prefill.py
@@ -0,0 +1,168 @@
import random
import pytest
import time

import torch
from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
context_attention_fwd)
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

NUM_HEADS = [12]
HEAD_SIZES = [128]
DTYPES = [torch.float16]


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_contexted_kv_attention(
num_heads: int,
head_size: int,
dtype: torch.dtype,
) -> None:
random.seed(0)
torch.manual_seed(0)
MAX_SEQ_LEN = 1024
MAX_CTX_LEN = 1024
BS = 10
cache_size = 640
block_size = 32
max_block_per_request = 64
subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]

num_tokens = sum(subquery_lens)
query = torch.empty(num_tokens,
num_heads,
head_size,
dtype=dtype,
device='cuda')
query.uniform_(-1e-3, 1e-3)
output = torch.empty(num_tokens,
num_heads,
head_size,
dtype=dtype,
device='cuda')

kv = torch.empty(sum(seq_lens),
2,
num_heads,
head_size,
dtype=dtype,
device='cuda')
kv.uniform_(-1e-3, 1e-3)
key, value = kv.unbind(dim=1)

k_cache = torch.zeros(cache_size,
block_size,
num_heads,
head_size,
dtype=dtype,
device='cuda')
v_cache = torch.zeros(cache_size,
block_size,
num_heads,
head_size,
dtype=dtype,
device='cuda')
k = torch.zeros(sum(subquery_lens),
num_heads,
head_size,
dtype=dtype,
device='cuda')
v = torch.zeros(sum(subquery_lens),
num_heads,
head_size,
dtype=dtype,
device='cuda')
values = torch.arange(0, cache_size, dtype=torch.long, device='cuda')
values = values[torch.randperm(cache_size)]
block_table = values[:BS * max_block_per_request].view(
BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long, device='cuda')
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long, device='cuda')
b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1],
dtype=torch.long,
device='cuda'),
dim=0)
max_input_len = MAX_SEQ_LEN
# copy kv to cache
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
dtype=torch.long,
device='cuda'),
dim=0)
for i in range(BS):
for j in range(subquery_lens[i]):
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
j])
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
b_ctx_len[i] + j])
cur_ctx = 0
block_id = 0
while cur_ctx < b_ctx_len[i]:
start_loc = b_seq_start_loc[i] + cur_ctx
if cur_ctx + block_size > b_ctx_len[i]:
end_loc = b_seq_start_loc[i] + b_ctx_len[i]
else:
end_loc = start_loc + block_size
start_slot = block_table[i, block_id] * block_size
end_slot = start_slot + end_loc - start_loc
k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
key[start_loc:end_loc])
v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
value[start_loc:end_loc])
cur_ctx += block_size
block_id += 1
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8,
8).permute(0, 2, 3, 1, 4).contiguous()
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_heads,
head_size).permute(0, 2, 3, 1).contiguous()

context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
b_start_loc, b_seq_len, b_ctx_len, max_input_len)
torch.cuda.synchronize()
start_time = time.time()
context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
b_start_loc, b_seq_len, b_ctx_len, max_input_len)
torch.cuda.synchronize()
end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")

scale = float(1.0 / (head_size**0.5))

attn_op = xops.fmha.cutlass.FwOp()

attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
subquery_lens, seq_lens)
output_ref = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
attn_bias=attn_bias,
p=0.0,
scale=scale,
op=attn_op,
)
torch.cuda.synchronize()
start_time = time.time()
output_ref = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
attn_bias=attn_bias,
p=0.0,
scale=scale,
op=attn_op,
)
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
output_ref = output_ref.squeeze(0)
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
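The two reshapes near the end of the test mirror the paged KV-cache layout the Triton kernel reads: keys as [num_blocks, num_kv_heads, head_size/8, block_size, 8] and values as [num_blocks, num_kv_heads, head_size, block_size]. A small self-contained sketch of the same layout change and its inverse, with arbitrary shapes (illustration only, independent of the test above):

import torch

num_blocks, block_size, num_heads, head_size = 4, 32, 12, 128

# Contiguous cache in [num_blocks, block_size, num_heads, head_size] layout.
k_flat = torch.randn(num_blocks, block_size, num_heads, head_size)
v_flat = torch.randn(num_blocks, block_size, num_heads, head_size)

# Key layout used above: split head_size into groups of 8 and move block_size
# inward, so each 8-wide slice ends up contiguous after .contiguous().
k_paged = (k_flat.view(num_blocks, block_size, num_heads, head_size // 8, 8)
                 .permute(0, 2, 3, 1, 4).contiguous())

# Value layout used above: [num_blocks, num_heads, head_size, block_size].
v_paged = v_flat.permute(0, 2, 3, 1).contiguous()

# Undoing the permutations recovers the original tensors, so the transform is
# purely a memory-layout change and loses no information.
k_back = k_paged.permute(0, 3, 1, 2, 4).reshape(num_blocks, block_size,
                                                num_heads, head_size)
v_back = v_paged.permute(0, 3, 1, 2)
assert torch.equal(k_back, k_flat) and torch.equal(v_back, v_flat)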
41 changes: 41 additions & 0 deletions tests/prefix_caching/test_prefix_caching.py
@@ -0,0 +1,41 @@
"""Compare the with and without prefix caching.
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
import pytest

from vllm import LLM, SamplingParams

prefix = (
"You are an expert school principal, skilled in effectively managing "
"faculty and staff. Draft 10-15 questions for a potential first grade "
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
"community, joyful discovery, and life-long learning. The candidate is "
"coming in for a first-round panel interview for a 8th grade Math "
"teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on these information, fulfill "
"the following paragraph: ")


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_tokens", [16])
def test_prefix_caching(
example_prompts,
model: str,
max_tokens: int,
):
llm = LLM(model=model)
# -1 since the last token can change when concatenating prompts.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
prompts = [prefix + prompt for prompt in example_prompts]
sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs_without_prefix = llm.generate(prompts, sampling_params)
outputs_with_prefix = llm.generate(prompts,
sampling_params,
prefix_pos=[prefix_pos] * len(prompts))
for output_without_prefix, output_with_prefix in zip(
outputs_without_prefix, outputs_with_prefix):
assert (output_without_prefix.outputs[0].token_ids ==
output_with_prefix.outputs[0].token_ids)
assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1
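The final assertion checks that every prompt mapped onto a single shared entry in the scheduler's prefix pool. A toy sketch of the general idea, keying cached prefixes by their block-aligned token IDs so identical prefixes deduplicate; this is an illustration only, not the PrefixPool implementation added by this commit:

from typing import Dict, List, Tuple


class ToyPrefixPool:
    """Deduplicates prefixes by their token IDs (illustrative only)."""

    def __init__(self, block_size: int) -> None:
        self.block_size = block_size
        self.prefixes: Dict[Tuple[int, ...], List[int]] = {}

    def add_or_get(self, token_ids: List[int]) -> List[int]:
        # Truncate to a multiple of block_size so the cached prefix lines up
        # with whole KV-cache blocks.
        length = len(token_ids) // self.block_size * self.block_size
        key = tuple(token_ids[:length])
        return self.prefixes.setdefault(key, list(key))


pool = ToyPrefixPool(block_size=16)
shared = [7] * 100  # stand-in for the tokenized prefix
pool.add_or_get(shared + [1, 2, 3])
pool.add_or_get(shared + [4, 5])
assert len(pool.prefixes) == 1  # both prompts reuse the same 96-token prefix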
18 changes: 12 additions & 6 deletions tests/samplers/test_sampler.py
@@ -66,7 +66,8 @@ def test_sampler_all_greedy(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@@ -105,7 +106,8 @@ def test_sampler_all_random(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@@ -140,7 +142,8 @@ def test_sampler_all_beam(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@@ -193,7 +196,8 @@ def test_sampler_mixed(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@@ -234,7 +238,8 @@ def pick_ith(token_ids, logits):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@@ -288,7 +293,8 @@ def test_sampler_top_k_top_p(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)

sample_probs = None

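The new subquery_lens argument reflects that, with prefix caching, only the uncached tail of each prompt is actually run through the model; in these sampler tests nothing is cached, so the subquery lengths equal the full prompt lengths. A sketch of that relationship (hypothetical helper, not a function from model_runner):

from typing import List


def compute_subquery_lens(prompt_lens: List[int],
                          cached_prefix_lens: List[int]) -> List[int]:
    """Tokens still to be processed after reusing a cached prefix."""
    return [p - c for p, c in zip(prompt_lens, cached_prefix_lens)]


prompt_lens = [100, 120, 90]
# No prefix cached (as in the sampler tests): subquery_lens == prompt_lens.
assert compute_subquery_lens(prompt_lens, [0, 0, 0]) == prompt_lens
# Reusing a 64-token cached prefix leaves only the tails to compute.
assert compute_subquery_lens(prompt_lens, [64, 64, 64]) == [36, 56, 26]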
5 changes: 3 additions & 2 deletions tests/worker/test_model_runner.py
@@ -33,11 +33,12 @@ def test_prepare_prompt():
expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += max_seq_len
input_tokens, input_positions, _, return_prompt_lens = (
input_tokens, input_positions, _, return_prompt_lens, _ = (
model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
assert input_tokens.shape == (batch_size, max_seq_len)
assert input_positions.shape == (batch_size, max_seq_len)
torch.testing.assert_close(input_tokens, input_positions)
4 changes: 4 additions & 0 deletions vllm/block.py
@@ -66,3 +66,7 @@ def __repr__(self) -> str:
return (f'PhysicalTokenBlock(device={self.device}, '
f'block_number={self.block_number}, '
f'ref_count={self.ref_count})')


# Mapping: logical block number -> physical block.
BlockTable = List[PhysicalTokenBlock]
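A BlockTable is simply a list indexed by logical block number. A hedged sketch of the slot arithmetic this enables, mirroring the start_slot = block_table[i, block_id] * block_size computation in the kernel test above (illustrative helper, not part of the commit; block numbers are plain ints here rather than PhysicalTokenBlock objects):

from typing import List


def token_position_to_slot(block_table: List[int], position: int,
                           block_size: int) -> int:
    """Map a token's logical position to its physical slot in the KV cache.

    block_table[i] holds the physical block backing logical block i, so the
    slot is that block's base offset plus the position within the block.
    """
    logical_block = position // block_size
    offset = position % block_size
    return block_table[logical_block] * block_size + offset


# Logical blocks 0..2 happen to live in physical blocks 5, 0 and 9.
table = [5, 0, 9]
assert token_position_to_slot(table, 0, block_size=16) == 80   # 5 * 16 + 0
assert token_position_to_slot(table, 20, block_size=16) == 4   # 0 * 16 + 4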