Skip to content

Commit

Permalink
chore: update vllm to use gptq quanitzed model (#378)
Browse files Browse the repository at this point in the history
* chore: update vllm to use gptq quanitzed model

* bug: fix catch-all wildcard for e2e workflow
  • Loading branch information
YrrepNoj committed Apr 10, 2024
1 parent 1a6f29e commit dc1029d
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/e2e.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ on:
pull_request:
paths:
# Catch-all
- "*"
- "**"

# Ignore updates to the .github directory, unless it's this current file
- "!.github/**"
Expand Down
6 changes: 3 additions & 3 deletions packages/vllm/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ COPY build/leapfrogai_api*.whl leapfrogai_api-100.100.100-py3-none-any.whl
RUN pip install "leapfrogai_api-100.100.100-py3-none-any.whl[vllm]" --no-index --find-links=build/

# download model
ARG REPO_ID=TheBloke/Synthia-7B-v2.0-AWQ
ARG REVISION=main
ARG REPO_ID=TheBloke/Synthia-7B-v2.0-GPTQ
ARG REVISION=gptq-4bit-32g-actorder_True
ENV HF_HOME=/home/leapfrogai/.cache/huggingface
COPY scripts/model_download.py scripts/model_download.py

RUN REPO_ID=${REPO_ID} FILENAME=${FILENAME} REVISION=${REVISION} python3.11 scripts/model_download.py

ENV QUANTIZATION=awq
ENV QUANTIZATION=gptq

COPY main.py .
COPY config.yaml .
Expand Down
3 changes: 3 additions & 0 deletions packages/vllm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def __init__(self):
quantization=os.environ["QUANTIZATION"] or None,
max_context_len_to_capture=self.backend_config.max_context_length,
worker_use_ray=True,
max_model_len=self.backend_config.max_context_length,
dtype="auto",
gpu_memory_utilization=0.90,
)
self.engine = AsyncLLMEngine.from_engine_args(self.engine_args)

Expand Down
7 changes: 4 additions & 3 deletions packages/vllm/scripts/model_download.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from huggingface_hub import snapshot_download
import os

REPO_ID = os.environ.get("REPO_ID", "TheBloke/Synthia-7B-v2.0-AWQ")
REVISION = os.environ.get("REVISION", "main")
from huggingface_hub import snapshot_download

REPO_ID = os.environ.get("REPO_ID", "TheBloke/Synthia-7B-v2.0-GPTQ")
REVISION = os.environ.get("REVISION", "gptq-4bit-32g-actorder_True")

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

Expand Down

0 comments on commit dc1029d

Please sign in to comment.