feat(server): flash attention v2 #624

Merged · 6 commits · Jul 18, 2023
Dockerfile · 15 changes: 14 additions & 1 deletion
@@ -98,6 +98,16 @@ COPY server/Makefile-flash-att Makefile
# Build specific version of flash attention
RUN make build-flash-attention

# Build Flash Attention v2 CUDA kernels
FROM kernel-builder as flash-att-v2-builder

WORKDIR /usr/src

COPY server/Makefile-flash-att-v2 Makefile

# Build specific version of flash attention v2
RUN make build-flash-attention-v2

# Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder

@@ -146,8 +156,11 @@
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
server/Makefile · 1 change: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
include Makefile-flash-att
include Makefile-flash-att-v2
include Makefile-vllm

unit-tests:
server/Makefile-flash-att-v2 · 13 changes: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc

flash-attention-v2:
	# Clone flash attention
	pip install packaging
	git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2

build-flash-attention-v2: flash-attention-v2
	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
	cd flash-attention-v2 && python setup.py build

install-flash-attention-v2: build-flash-attention-v2
	cd flash-attention-v2 && python setup.py install
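A quick way to verify the build, either after `make install-flash-attention-v2` or inside the final Docker image (where the build artifacts above are copied into site-packages), is to import the compiled extension. A minimal sketch, assuming the v2 build exposes a `flash_attn_2_cuda` module (the module name is not shown in this diff; the v1 build exposes `flash_attn_cuda`):

try:
    import flash_attn_2_cuda  # assumed extension name for the v2 kernels
    print("Flash Attention v2 kernels available")
except ImportError as err:
    print(f"Flash Attention v2 kernels missing: {err}")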
server/text_generation_server/models/__init__.py · 54 changes: 12 additions & 42 deletions
@@ -42,51 +42,21 @@
"get_model",
]

FLASH_ATT_ERROR_MESSAGE = (
"{} requires CUDA and Flash Attention kernels to be installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
)
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True
try:
if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
if not torch.cuda.is_available():
FLASH_ATT_ERROR_MESSAGE = (
"{} requires CUDA. No compatible CUDA devices found."
)
raise ImportError("CUDA is not available")

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0

supported = is_sm75 or is_sm8x or is_sm90
if not supported:
FLASH_ATT_ERROR_MESSAGE = (
"{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
"No compatible CUDA device found."
)
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
)

from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_neox import FlashNeoXSharded
from text_generation_server.models.flash_llama import (
FlashLlama,
)
from text_generation_server.models.flash_santacoder import (
FlashSantacoderSharded,
)

FLASH_ATTENTION = True
else:
FLASH_ATTENTION = False
except ImportError:
logger.opt(exception=True).warning(
"Could not import Flash Attention enabled models"
from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_neox import FlashNeoXSharded
from text_generation_server.models.flash_llama import (
FlashLlama,
)
from text_generation_server.models.flash_santacoder import (
FlashSantacoderSharded,
)

except ImportError as e:
logger.warning(f"Could not import Flash Attention enabled models: {e}")
FLASH_ATTENTION = False

if FLASH_ATTENTION:
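The CUDA-availability and compute-capability checks deleted above do not disappear; the modeling files below now import a shared `attention` helper from `text_generation_server.utils.flash_attn`, so the detection presumably moves there. That module is not part of the hunks shown here; the following is only a sketch of what the removed block implies, with all names assumed:

import os
import torch

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("USE_FLASH_ATTENTION is set to false")
if not torch.cuda.is_available():
    raise ImportError("CUDA is not available")

# Same device gate as the removed code: SM 7.5, 8.x or 9.0.
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
if not (is_sm75 or is_sm8x or is_sm90):
    raise ImportError(f"GPU with CUDA capability {major} {minor} is not supported")

try:
    import flash_attn_2_cuda  # assumed name of the v2 extension
    HAS_FLASH_ATTN_V2 = True
except ImportError:
    import flash_attn_cuda  # v1 kernels used by the old inline calls
    HAS_FLASH_ATTN_V2 = False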
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -26,13 +26,13 @@
from typing import Optional, List, Tuple

# Flash attention imports
import flash_attn_cuda
import dropout_layer_norm

# vllm imports
import vllm_cache_ops
import vllm_attention_ops

from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@@ -164,22 +164,14 @@ def forward(
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn_cuda.fwd(
attention(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:
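Every prefill call site in the modeling files is rewritten the same way: the inlined `flash_attn_cuda.fwd(...)` call becomes a call to the shared `attention` helper, which keeps only the query/key/value slices, the output buffer, `cu_seqlen_prefill`, `max_s` and the softmax scale. The helper itself lives in `text_generation_server/utils/flash_attn.py`, which is not shown in this diff; a hedged sketch of its shape, with the argument order inferred from the call sites and the v1 argument list copied from the removed code:

import flash_attn_cuda  # v1 kernels (module used by the removed inline calls)

try:
    import flash_attn_2_cuda  # assumed extension name for the v2 kernels
    HAS_FLASH_ATTN_V2 = True
except ImportError:
    HAS_FLASH_ATTN_V2 = False


def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale):
    # Causal prefill attention over variable-length sequences packed with cu_seqlens.
    if HAS_FLASH_ATTN_V2:
        # Dispatch to the v2 kernel here; its exact entry point and argument
        # list are not shown in this diff, so they are deliberately omitted.
        raise NotImplementedError("Flash Attention v2 dispatch goes here")
    # v1 fallback: the same arguments the models previously passed inline.
    return flash_attn_cuda.fwd(
        q, k, v, out,
        cu_seqlens, cu_seqlens,
        max_s, max_s,
        0.0,            # dropout probability
        softmax_scale,
        False,
        True,           # causal
        False,
        0,
        None,
    )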
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -27,13 +27,11 @@
from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional, List, Tuple

# Flash attention imports
import flash_attn_cuda

# vllm imports
import vllm_cache_ops
import vllm_attention_ops

from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@@ -153,22 +151,14 @@ def forward(
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn_cuda.fwd(
attention(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -6,13 +6,11 @@
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

# Flash attention imports
import flash_attn_cuda

# vllm imports
import vllm_cache_ops
import vllm_attention_ops

from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@@ -182,27 +180,15 @@ def forward(

# Prefill
if cu_seqlen_prefill is not None:
if self.num_heads_kv == 1:
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)

# flash attention
flash_attn_cuda.fwd(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:
@@ -314,30 +300,15 @@ def forward(

# Prefill
if cu_seqlen_prefill is not None:
# Expand to query shape
kv = (
kv.unsqueeze(2)
.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
)

# flash attention
flash_attn_cuda.fwd(
attention(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
attn_output,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:
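Note that both attention variants in this file previously expanded the shared key/value heads to the full query head count before calling the v1 kernel; those expands are removed, which suggests the new helper (and the v2 kernels behind it) accepts multi-query/grouped KV directly, i.e. k and v with fewer heads than q. A shape-only sketch under that assumption, reusing the `attention` helper sketched above:

import torch

total_tokens, num_heads, num_kv_heads, head_size = 16, 32, 2, 64
q = torch.randn(total_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
k = torch.randn(total_tokens, num_kv_heads, head_size, dtype=torch.float16, device="cuda")
v = torch.randn(total_tokens, num_kv_heads, head_size, dtype=torch.float16, device="cuda")
out = torch.empty_like(q)
cu_seqlens = torch.tensor([0, 16], dtype=torch.int32, device="cuda")  # one 16-token sequence

# Old path: kv first had to be expanded/reshaped to num_heads KV heads.
# New path, per the removed expands above: pass the narrower k/v as-is.
attention(q, k, v, out, cu_seqlens, max_s=16, softmax_scale=head_size ** -0.5)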
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -5,13 +5,11 @@
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple

# Flash attention imports
import flash_attn_cuda

# vllm imports
import vllm_cache_ops
import vllm_attention_ops

from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@@ -271,26 +269,15 @@ def forward(

# Prefill
if cu_seqlen_prefill is not None:
# Expand from 1 to num_heads
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)

# flash attention
flash_attn_cuda.fwd(
attention(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else: