feat: Add dbrx support (#1685)
Close #1679
OlivierDehaene committed Mar 29, 2024
1 parent 762dbf3 commit f04255c
Showing 4 changed files with 1,180 additions and 0 deletions.
server/text_generation_server/models/__init__.py (24 additions, 0 deletions)
@@ -71,6 +71,7 @@
     from text_generation_server.models.flash_mixtral import FlashMixtral
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
+    from text_generation_server.models.flash_dbrx import FlashDbrx
     from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 
 except ImportError as e:
@@ -86,6 +87,7 @@
     __all__.append(IDEFICSSharded)
     __all__.append(FlashMistral)
     __all__.append(FlashMixtral)
+    __all__.append(FlashDbrx)
     __all__.append(FlashPhi)
     __all__.append(FlashQwen2)
     __all__.append(FlashStarcoder2)
@@ -381,6 +383,28 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
+    if model_type == "dbrx":
+        if FLASH_ATTENTION:
+            return FlashDbrx(
+                model_id,
+                revision,
+                quantize=quantize,
+                use_medusa=use_medusa,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        elif sharded:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
+        else:
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                use_medusa=use_medusa,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
     if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
         if sharded:
             if FLASH_ATTENTION:
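For context, a minimal sketch of how the new branch in get_model would be exercised. The keyword arguments mirror the call sites visible in the diff and the model id is illustrative; neither is taken verbatim from this commit:

# Hypothetical usage sketch: assumes get_model accepts these keyword
# arguments; "databricks/dbrx-instruct" stands in for any checkpoint whose
# config.json declares model_type == "dbrx".
from text_generation_server.models import get_model

model = get_model(
    model_id="databricks/dbrx-instruct",
    revision=None,
    sharded=False,
    quantize=None,
    use_medusa=None,
    dtype=None,
    trust_remote_code=False,
)
# With FLASH_ATTENTION available, this returns a FlashDbrx instance.
# Without it, dispatch falls back to CausalLM, and requesting a sharded
# DBRX model without flash attention raises NotImplementedError.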
