Skip to content

Commit

Permalink
Enhance Model Delivery
Browse files Browse the repository at this point in the history
This PR introduces a few enhancements:
- Allow to override temporary path via environment variable `MLC_TEMP_DIR`;
- Add a 10-time retry when uploading the quantized weights to
  HuggingFace Hub. It could fail at times;
- Echo the commands being used to quantize the models in `logs.txt`;
- Fix a compatibility issue when pulling individual weights down from
  HuggingFace Hub in Git LFS.
  • Loading branch information
junrushao committed Nov 17, 2023
1 parent 2600b9a commit c0d9cce
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 46 deletions.
90 changes: 47 additions & 43 deletions python/mlc_chat/cli/delivery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dataclasses
import json
import logging
import os
import shutil
import subprocess
import tempfile
Expand All @@ -24,6 +25,7 @@
)

logger = logging.getLogger(__name__)
MLC_TEMP_DIR = os.getenv("MLC_TEMP_DIR", None)


@dataclasses.dataclass
Expand Down Expand Up @@ -58,7 +60,7 @@ def __exit__(self, exc_type, exc_value, traceback):

def create_temp_dir(self) -> Path:
"""Create a temporary directory that will be deleted when exiting the scope."""
temp_dir = tempfile.mkdtemp()
temp_dir = tempfile.mkdtemp(dir=MLC_TEMP_DIR)
self.add(lambda: shutil.rmtree(temp_dir, ignore_errors=True))
return Path(temp_dir)

Expand Down Expand Up @@ -94,47 +96,41 @@ def _run_quantization(
api.create_repo(repo_id=repo, private=False)
logger.info("[HF] Repo recreated")
succeeded = True
with tempfile.TemporaryDirectory() as output_dir:
with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as output_dir:
log_path = Path(output_dir) / "logs.txt"
with log_path.open("a", encoding="utf-8") as log_file:
assert isinstance(model_info.model, Path)
logger.info("[MLC] Processing in directory: %s", output_dir)
subprocess.run(
[
"mlc_chat",
"gen_mlc_chat_config",
"--model",
str(model_info.model),
"--quantization",
model_info.quantization,
"--conv-template",
model_info.conv_template,
"--context-window-size",
str(model_info.context_window_size),
"--output",
output_dir,
],
check=True,
stdout=log_file,
stderr=subprocess.STDOUT,
)
subprocess.run(
[
"mlc_chat",
"convert_weight",
"--model",
str(model_info.model),
"--quantization",
model_info.quantization,
"--source-format",
model_info.source_format,
"--output",
output_dir,
],
check=False,
stdout=log_file,
stderr=subprocess.STDOUT,
)
cmd = [
"mlc_chat",
"gen_mlc_chat_config",
"--model",
str(model_info.model),
"--quantization",
model_info.quantization,
"--conv-template",
model_info.conv_template,
"--context-window-size",
str(model_info.context_window_size),
"--output",
output_dir,
]
print(" ".join(cmd), file=log_file, flush=True)
subprocess.run(cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT)
cmd = [
"mlc_chat",
"convert_weight",
"--model",
str(model_info.model),
"--quantization",
model_info.quantization,
"--source-format",
model_info.source_format,
"--output",
output_dir,
]
print(" ".join(cmd), file=log_file, flush=True)
subprocess.run(cmd, check=False, stdout=log_file, stderr=subprocess.STDOUT)
logger.info("[MLC] Complete!")
if not (Path(output_dir) / "ndarray-cache.json").exists():
logger.error(
Expand All @@ -145,11 +141,19 @@ def _run_quantization(
)
succeeded = False
logger.info("[HF] Uploading to: https://huggingface.co/%s", repo)
api.upload_folder(
folder_path=output_dir,
repo_id=repo,
commit_message="Initial commit",
)
for _retry in range(10):
try:
api.upload_folder(
folder_path=output_dir,
repo_id=repo,
commit_message="Initial commit",
)
except Exception as exc: # pylint: disable=broad-except
logger.error("[%s] %s. Retrying...", red("FAILED"), exc)
else:
break
else:
raise RuntimeError("Failed to upload to HuggingFace Hub with 10 retries")
return succeeded


Expand Down
8 changes: 5 additions & 3 deletions python/mlc_chat/support/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

logger = logging.getLogger(__name__)

MLC_TEMP_DIR = os.getenv("MLC_TEMP_DIR", None)


def get_cache_dir() -> Path:
"""Return the path to the cache directory."""
Expand Down Expand Up @@ -58,7 +60,7 @@ def git_clone(url: str, destination: Path, ignore_lfs: bool) -> None:
command = ["git", "clone", url, repo_name]
_ensure_directory_not_exist(destination, force_redo=False)
try:
with tempfile.TemporaryDirectory() as tmp_dir:
with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir:
logger.info("[Git] Cloning %s to %s", url, destination)
subprocess.run(
command,
Expand Down Expand Up @@ -94,7 +96,7 @@ def git_lfs_pull(repo_dir: Path) -> None:
for file in tqdm.tqdm(filenames):
logger.info("[Git LFS] Downloading %s", file)
subprocess.check_output(
["git", "-C", str(repo_dir), "lfs", "pull", file],
["git", "-C", str(repo_dir), "lfs", "pull", "--include", file],
stderr=subprocess.STDOUT,
)

Expand Down Expand Up @@ -144,7 +146,7 @@ def download_mlc_weights( # pylint: disable=too-many-locals
except ValueError:
logger.info("Weights already downloaded: %s", git_dir)
return
with tempfile.TemporaryDirectory() as tmp_dir_prefix:
with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir_prefix:
tmp_dir = Path(tmp_dir_prefix) / "tmp"
git_url = git_url_template.format(user=user, repo=repo)
git_clone(git_url, tmp_dir, ignore_lfs=True)
Expand Down

0 comments on commit c0d9cce

Please sign in to comment.