[Misc] Enhance attention selector (vllm-project#4751)
WoosukKwon authored and robertgshaw2-neuralmagic committed May 19, 2024
1 parent 270c0c2 commit c944527
Showing 49 changed files with 573 additions and 220 deletions.
1 change: 0 additions & 1 deletion tests/worker/test_model_runner.py
@@ -307,7 +307,6 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):

assert len(attn_metadata.slot_mapping) == len(input_tokens)
assert len(input_positions) == len(input_tokens)
assert attn_metadata.kv_cache_dtype == "auto"
assert attn_metadata.num_prefills == prefill_batch_size
if enforce_eager:
assert attn_metadata.num_decode_tokens == decode_batch_size
4 changes: 2 additions & 2 deletions vllm/attention/__init__.py
@@ -5,9 +5,9 @@
from vllm.attention.selector import get_attn_backend

__all__ = [
"Attention",
"AttentionBackend",
"AttentionMetadata",
"Attention",
"get_attn_backend",
"AttentionMetadataPerStage",
"get_attn_backend",
]
5 changes: 2 additions & 3 deletions vllm/attention/backends/abstract.py
@@ -94,8 +94,6 @@ class AttentionMetadata(Generic[T]):
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# The kv cache's data type.
kv_cache_dtype: str

def __post_init__(self):
if self.num_prefill_tokens > 0:
@@ -116,6 +114,7 @@ def __init__(
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None:
raise NotImplementedError

@@ -127,6 +126,6 @@ def forward(
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
kv_scale: float,
kv_scale: float = 1.0,
) -> torch.Tensor:
raise NotImplementedError
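
For orientation, the net effect of the abstract.py change is that a backend implementation now receives the KV-cache dtype once, at construction time, instead of reading it from every AttentionMetadata, and kv_scale gains a default of 1.0. A minimal sketch of the resulting contract (illustrative class name, not vLLM source):

from typing import List, Optional

import torch


class MyAttentionImpl:
    """Sketch of a backend implementation against the updated interface."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",  # new constructor argument
    ) -> None:
        # The KV-cache dtype is stored on the impl, not on AttentionMetadata.
        self.kv_cache_dtype = kv_cache_dtype

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: "AttentionMetadata",  # per-batch metadata, no dtype field
        kv_scale: float = 1.0,  # now defaulted
    ) -> torch.Tensor:
        raise NotImplementedError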
13 changes: 7 additions & 6 deletions vllm/attention/backends/flash_attn.py
@@ -140,16 +140,18 @@ def __init__(
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -167,7 +169,7 @@ def forward(
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata[FlashAttentionMetadata],
kv_scale: float,
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@@ -196,8 +198,7 @@ def forward(
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
kv_scale)
self.kv_cache_dtype, kv_scale)

num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
@@ -264,7 +265,7 @@ def forward(
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
33 changes: 23 additions & 10 deletions vllm/attention/backends/flashinfer.py
@@ -149,20 +149,33 @@ def __init__(
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None:
if sliding_window is not None:
raise ValueError("Sliding window is not supported in FlashInfer.")
self.sliding_window = (-1, -1)
self.alibi_slopes = alibi_slopes
self.scale = scale
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if sliding_window is not None:
raise ValueError("Sliding window is not supported in FlashInfer.")
self.sliding_window = (-1, -1)
self.kv_cache_dtype = kv_cache_dtype

def forward(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[FlashInferMetadata],
kv_scale: float):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[FlashInferMetadata],
kv_scale: float = 1.0,
) -> torch.Tensor:
assert kv_scale == 1.0
num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
@@ -183,7 +196,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
attn_metadata.kv_cache_dtype,
self.kv_cache_dtype,
)

if prefill_meta := attn_metadata.prefill_metadata:
16 changes: 9 additions & 7 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -138,25 +138,27 @@ def __init__(
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

suppored_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
f"Supported head sizes are: {supported_head_sizes}.")

self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
@@ -229,7 +231,7 @@ def forward(
key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
self.kv_cache_dtype,
kv_scale,
)

@@ -323,7 +325,7 @@ def forward(
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
28 changes: 17 additions & 11 deletions vllm/attention/backends/torch_sdpa.py
@@ -83,26 +83,32 @@ def __init__(
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
if alibi_slopes is not None:
assert len(alibi_slopes) == num_heads
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)

supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
f"Supported head sizes are: {supported_head_sizes}.")
if kv_cache_dtype != "auto":
raise NotImplementedError(
"Torch SDPA backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")

def forward(
self,
@@ -111,7 +117,7 @@ def forward(
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata, # type: ignore
kv_scale: float,
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
@@ -124,6 +130,7 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert kv_scale == 1.0
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
@@ -136,8 +143,7 @@ def forward(
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
kv_scale)
self.kv_cache_dtype, kv_scale)

if attn_metadata.is_prompt:
assert attn_metadata.seq_lens is not None
@@ -195,7 +201,7 @@ def forward(
attn_metadata.block_tables,
attn_metadata.seq_lens_tensor,
attn_metadata.max_seq_len,
attn_metadata.kv_cache_dtype,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
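
One behavioral consequence of the torch_sdpa.py change: an unsupported KV-cache dtype is now rejected when the impl is constructed rather than surfacing later. A rough illustration (the TorchSDPABackendImpl name and the argument values are assumptions inferred from this diff, not verified against the file):

from vllm.attention.backends.torch_sdpa import TorchSDPABackendImpl

try:
    # head_size=128 is assumed to be among PagedAttention's supported sizes,
    # so the FP8 check is what fires here.
    TorchSDPABackendImpl(num_heads=32, head_size=128, scale=128**-0.5,
                         kv_cache_dtype="fp8")
except NotImplementedError as exc:
    print(exc)  # Torch SDPA backend does not support FP8 KV cache. ...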
12 changes: 6 additions & 6 deletions vllm/attention/backends/xformers.py
@@ -149,15 +149,17 @@ def __init__(
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -175,7 +177,7 @@ def forward(
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[XFormersMetadata],
kv_scale: float,
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
@@ -188,7 +190,6 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""
num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
@@ -203,8 +204,7 @@ def forward(
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
kv_scale)
self.kv_cache_dtype, kv_scale)

num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
@@ -262,7 +262,7 @@ def forward(
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
19 changes: 17 additions & 2 deletions vllm/attention/layer.py
@@ -7,6 +7,7 @@
from vllm.attention.backends.abstract import (AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig


class Attention(nn.Module):
@@ -29,10 +30,24 @@ def __init__(
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__()
self.backend = get_attn_backend(torch.get_default_dtype())
impl_cls = self.backend.get_impl_cls()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if num_kv_heads is None:
num_kv_heads = num_heads
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
sliding_window, dtype, kv_cache_dtype,
block_size)
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window)

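
Taken together with the selector change, layer.py now hands the attention selector the head geometry, the model dtype, the KV-cache dtype, and the block size up front. A minimal sketch of the resulting call, using the positional order shown in the diff above and illustrative values:

import torch

from vllm.attention.selector import get_attn_backend

# Illustrative values; in Attention.__init__ the kv_cache_dtype and block_size
# come from cache_config (or fall back to "auto" and 16), and dtype is
# torch.get_default_dtype() at model-initialization time.
attn_backend = get_attn_backend(
    32,             # num_heads
    128,            # head_size
    8,              # num_kv_heads
    None,           # sliding_window
    torch.float16,  # dtype (model weight/activation dtype)
    "auto",         # kv_cache_dtype
    16,             # block_size
)
impl_cls = attn_backend.get_impl_cls()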