Skip to content

Commit

Permalink
feat(server): flash attention v2 (#624)
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene committed Jul 18, 2023
1 parent 4d38a1c commit 3b71c38
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 112 deletions.
15 changes: 14 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@ COPY server/Makefile-flash-att Makefile
# Build specific version of flash attention
RUN make build-flash-attention

# Build Flash Attention v2 CUDA kernels
FROM kernel-builder as flash-att-v2-builder

WORKDIR /usr/src

COPY server/Makefile-flash-att-v2 Makefile

# Build specific version of flash attention v2
RUN make build-flash-attention-v2

# Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder

Expand Down Expand Up @@ -146,8 +156,11 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
Expand Down
1 change: 1 addition & 0 deletions server/Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
include Makefile-flash-att
include Makefile-flash-att-v2
include Makefile-vllm

unit-tests:
Expand Down
13 changes: 13 additions & 0 deletions server/Makefile-flash-att-v2
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc

flash-attention-v2:
# Clone flash attention
pip install packaging
git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2

build-flash-attention-v2: flash-attention-v2
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
cd flash-attention-v2 && python setup.py build

install-flash-attention-v2: build-flash-attention-v2
cd flash-attention-v2 && python setup.py install
54 changes: 12 additions & 42 deletions server/text_generation_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,51 +42,21 @@
"get_model",
]

FLASH_ATT_ERROR_MESSAGE = (
"{} requires CUDA and Flash Attention kernels to be installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
)
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True
try:
if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
if not torch.cuda.is_available():
FLASH_ATT_ERROR_MESSAGE = (
"{} requires CUDA. No compatible CUDA devices found."
)
raise ImportError("CUDA is not available")

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0

supported = is_sm75 or is_sm8x or is_sm90
if not supported:
FLASH_ATT_ERROR_MESSAGE = (
"{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
"No compatible CUDA device found."
)
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
)

from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_neox import FlashNeoXSharded
from text_generation_server.models.flash_llama import (
FlashLlama,
)
from text_generation_server.models.flash_santacoder import (
FlashSantacoderSharded,
)

FLASH_ATTENTION = True
else:
FLASH_ATTENTION = False
except ImportError:
logger.opt(exception=True).warning(
"Could not import Flash Attention enabled models"
from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_neox import FlashNeoXSharded
from text_generation_server.models.flash_llama import (
FlashLlama,
)
from text_generation_server.models.flash_santacoder import (
FlashSantacoderSharded,
)

except ImportError as e:
logger.warning(f"Could not import Flash Attention enabled models: {e}")
FLASH_ATTENTION = False

if FLASH_ATTENTION:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@
from typing import Optional, List, Tuple

# Flash attention imports
import flash_attn_cuda
import dropout_layer_norm

# vllm imports
import vllm_cache_ops
import vllm_attention_ops

from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
Expand Down Expand Up @@ -164,22 +164,14 @@ def forward(
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn_cuda.fwd(
attention(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,11 @@
from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional, List, Tuple

# Flash attention imports
import flash_attn_cuda

# vllm imports
import vllm_cache_ops
import vllm_attention_ops

from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
Expand Down Expand Up @@ -153,22 +151,14 @@ def forward(
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn_cuda.fwd(
attention(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

# Flash attention imports
import flash_attn_cuda

# vllm imports
import vllm_cache_ops
import vllm_attention_ops

from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
Expand Down Expand Up @@ -182,27 +180,15 @@ def forward(

# Prefill
if cu_seqlen_prefill is not None:
if self.num_heads_kv == 1:
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)

# flash attention
flash_attn_cuda.fwd(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:
Expand Down Expand Up @@ -314,30 +300,15 @@ def forward(

# Prefill
if cu_seqlen_prefill is not None:
# Expand to query shape
kv = (
kv.unsqueeze(2)
.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
)

# flash attention
flash_attn_cuda.fwd(
attention(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
attn_output,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple

# Flash attention imports
import flash_attn_cuda

# vllm imports
import vllm_cache_ops
import vllm_attention_ops

from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
Expand Down Expand Up @@ -271,26 +269,15 @@ def forward(

# Prefill
if cu_seqlen_prefill is not None:
# Expand from 1 to num_heads
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)

# flash attention
flash_attn_cuda.fwd(
attention(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:
Expand Down
Loading

0 comments on commit 3b71c38

Please sign in to comment.