Commit

fix
OlivierDehaene committed Jul 18, 2023
1 parent bc2f351 commit d186b13
Showing 4 changed files with 38 additions and 29 deletions.
4 changes: 2 additions & 2 deletions server/Makefile-flash-att-v2
@@ -1,12 +1,12 @@
-flash_att_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
+flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
 
 flash-attention-v2:
 	# Clone flash attention
 	pip install packaging
 	git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
 
 build-flash-attention-v2: flash-attention-v2
-	cd flash-attention-v2 && git fetch && git checkout $(flash_att_commit)
+	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
 	cd flash-attention-v2 && python setup.py build
 
 install-flash-attention-v2: build-flash-attention-v2
6 changes: 2 additions & 4 deletions server/text_generation_server/models/__init__.py
@@ -55,10 +55,8 @@
         FlashSantacoderSharded,
     )
 
-except ImportError:
-    logger.opt(exception=True).warning(
-        "Could not import Flash Attention enabled models"
-    )
+except ImportError as e:
+    logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
 
 if FLASH_ATTENTION:
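The rewritten handler drops the full traceback that logger.opt(exception=True) would attach and logs only the exception message before falling back. A minimal sketch of the resulting optional-import pattern, with one illustrative import standing in for the real block of flash-attention model imports (FlashSantacoderSharded and friends):

from loguru import logger

FLASH_ATTENTION = True
try:
    # Stand-in for the real block of flash-attention model imports.
    from text_generation_server.models.flash_llama import FlashLlama  # noqa: F401
except ImportError as e:
    # Missing flash attention is an expected condition on many machines,
    # so a one-line warning is enough; no traceback is logged.
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False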
@@ -188,7 +188,7 @@ def forward(
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
-                self.softmax_scale
+                self.softmax_scale,
             )
         # Decode
         else:
@@ -308,7 +308,7 @@ def forward(
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
-                self.softmax_scale
+                self.softmax_scale,
             )
         # Decode
         else:
53 changes: 32 additions & 21 deletions server/text_generation_server/utils/flash_attn.py
@@ -1,6 +1,8 @@
 import os
 import torch
 
+from loguru import logger
+
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 
@@ -18,10 +20,11 @@
     try:
         import flash_attn_2_cuda
     except ImportError:
-        raise ImportError("Flash Attention V2 is not installed.\n"
-                          "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                          "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
-                          )
+        raise ImportError(
+            "Flash Attention V2 is not installed.\n"
+            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+            "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
+        )
     if not (is_sm8x or is_sm90):
         raise ImportError(
             f"GPU with CUDA capability {major} {minor} is not supported for "
@@ -32,26 +35,28 @@
     try:
         import flash_attn_cuda
     except ImportError:
-        raise ImportError("Flash Attention is not installed.\n"
-                          "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                          "or install flash attention with `cd server && make install install-flash-attention`"
-                          ) from e
+        raise ImportError(
+            "Flash Attention is not installed.\n"
+            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+            "or install flash attention with `cd server && make install install-flash-attention`"
+        ) from e
 
     if not (is_sm75 or is_sm8x or is_sm90):
         raise ImportError(
             f"GPU with CUDA capability {major} {minor} is not supported"
         ) from e
+    logger.warning(f"Unable to use Flash Attention V2: {e}")
     HAS_FLASH_ATTN = True
 
 
 def attention(
-        q,
-        k,
-        v,
-        out,
-        cu_seqlens,
-        max_s,
-        softmax_scale,
+    q,
+    k,
+    v,
+    out,
+    cu_seqlens,
+    max_s,
+    softmax_scale,
 ):
     if HAS_FLASH_ATTN_V2:
         return flash_attn_2_cuda.varlen_fwd(
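Taken together, the two guarded import blocks in this file implement a graceful fallback: try the Flash Attention V2 kernel first and, if that raises, fall back to V1 with a warning instead of crashing. A condensed sketch of that flow, assuming the sm75/sm8x/sm90 flags are derived from torch.cuda.get_device_capability():

import torch
from loguru import logger

HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2 = False

# Assumed derivation of the capability flags used below.
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0

try:
    try:
        import flash_attn_2_cuda
    except ImportError:
        raise ImportError("Flash Attention V2 is not installed.")
    if not (is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported for Flash Attention V2"
        )
    HAS_FLASH_ATTN_V2 = True
except ImportError as e:
    # V2 is unavailable: try the V1 kernel and keep serving, but log why.
    try:
        import flash_attn_cuda
    except ImportError:
        raise ImportError("Flash Attention is not installed.") from e
    if not (is_sm75 or is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported"
        ) from e
    logger.warning(f"Unable to use Flash Attention V2: {e}")
    HAS_FLASH_ATTN = True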
@@ -76,21 +81,27 @@ def attention(
     if k.shape[1] != q.shape[1]:
         # MQA expand
         if k.shape[1] == 1:
-            k = k.expand(-1, q.shape[1], -1, -1)
+            k = k.expand(-1, q.shape[1], -1)
         # Grouped attention reshape
         else:
             original_shape = k.shape
-            k = k.unsqueeze(2).expand(-1, -1, q.shape[1], -1, -1) \
-                .reshape(original_shape[0], -1, original_shape[1], original_shape[2])
+            k = (
+                k.unsqueeze(2)
+                .expand(-1, -1, q.shape[1] // k.shape[1], -1)
+                .reshape(original_shape[0], -1, original_shape[2])
+            )
     if v.shape[1] != q.shape[1]:
         # MQA expand
         if v.shape[1] == 1:
-            v = v.expand(-1, q.shape[1], -1, -1)
+            v = v.expand(-1, q.shape[1], -1)
         # Grouped attention reshape
         else:
             original_shape = v.shape
-            v = v.unsqueeze(2).expand(-1, -1, q.shape[1], -1, -1) \
-                .reshape(original_shape[0], -1, original_shape[1], original_shape[2])
+            v = (
+                v.unsqueeze(2)
+                .expand(-1, -1, q.shape[1] // v.shape[1], -1)
+                .reshape(original_shape[0], -1, original_shape[2])
+            )
 
     return flash_attn_cuda.fwd(
         q,
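This is the substantive fix in the commit: inside attention(), k and v are laid out as [total_tokens, num_heads, head_dim], so the previous 4-D expand/reshape indexed the wrong dimensions. A standalone sketch of the corrected grouped-attention expansion (shapes illustrative), checked against torch.repeat_interleave:

import torch

total_tokens, num_q_heads, num_kv_heads, head_dim = 16, 8, 2, 64
q = torch.randn(total_tokens, num_q_heads, head_dim)
k = torch.randn(total_tokens, num_kv_heads, head_dim)

if k.shape[1] != q.shape[1]:
    if k.shape[1] == 1:
        # MQA: broadcast the single kv head to every query head.
        k = k.expand(-1, q.shape[1], -1)
    else:
        # GQA: repeat each kv head (num_q_heads // num_kv_heads) times.
        original_shape = k.shape
        k = (
            k.unsqueeze(2)                                   # [tokens, kv_heads, 1, dim]
            .expand(-1, -1, q.shape[1] // k.shape[1], -1)    # [tokens, kv_heads, groups, dim]
            .reshape(original_shape[0], -1, original_shape[2])  # [tokens, q_heads, dim]
        )
assert k.shape == q.shape

# Sanity check: the expand/reshape matches repeat_interleave on the head axis.
k_ref = torch.randn(total_tokens, num_kv_heads, head_dim)
groups = num_q_heads // num_kv_heads
expanded = (
    k_ref.unsqueeze(2)
    .expand(-1, -1, groups, -1)
    .reshape(total_tokens, -1, head_dim)
)
assert torch.equal(expanded, k_ref.repeat_interleave(groups, dim=1))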
