Don't download both safetensor and bin files. (vllm-project#2480)
NikolaBorisov authored and hongxiayang committed Jan 18, 2024
1 parent 3d96adf commit d246a38
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions vllm/model_executor/weight_utils.py
@@ -1,12 +1,13 @@
"""Utilities for downloading and initializing model weights."""
import filelock
import glob
import fnmatch
import json
import os
from collections import defaultdict
from typing import Any, Iterator, List, Optional, Tuple

from huggingface_hub import snapshot_download
from huggingface_hub import snapshot_download, HfFileSystem
import numpy as np
from safetensors.torch import load_file, save_file, safe_open
import torch
@@ -149,6 +150,20 @@ def prepare_hf_model_weights(
         allow_patterns += ["*.pt"]
 
     if not is_local:
+        # Before we download, we look at what is available:
+        fs = HfFileSystem()
+        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
+
+        # Depending on what is available, we download different things.
+        for pattern in allow_patterns:
+            matching = fnmatch.filter(file_list, pattern)
+            if len(matching) > 0:
+                allow_patterns = [pattern]
+                if pattern == "*.safetensors":
+                    use_safetensors = True
+                break
+
+        logger.info(f"Downloading model weights {allow_patterns}")
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
         with get_lock(model_name_or_path, cache_dir):
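
To make the new selection step easier to follow, here is a minimal standalone sketch of the same logic: list the repository up front with HfFileSystem, keep only the first entry of allow_patterns that matches any file, and remember whether that pattern is *.safetensors. The repository id "facebook/opt-125m" is only an illustrative placeholder and is not part of this change; the sketch needs network access to the Hugging Face Hub.

import fnmatch

from huggingface_hub import HfFileSystem

# Candidate weight formats, in the order the commit would try them.
allow_patterns = ["*.safetensors", "*.bin", "*.pt"]
use_safetensors = False

fs = HfFileSystem()
# detail=False returns plain paths such as "facebook/opt-125m/pytorch_model.bin".
file_list = fs.ls("facebook/opt-125m", detail=False)

for pattern in allow_patterns:
    if fnmatch.filter(file_list, pattern):
        # Narrow the download to the first format actually present in the repo,
        # so only one set of weight files is fetched later.
        allow_patterns = [pattern]
        use_safetensors = pattern == "*.safetensors"
        break

print(allow_patterns, use_safetensors)
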
@@ -163,8 +178,6 @@ def prepare_hf_model_weights(
     for pattern in allow_patterns:
         hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
         if len(hf_weights_files) > 0:
-            if pattern == "*.safetensors":
-                use_safetensors = True
             break
     if not use_safetensors:
         # Exclude files that are not needed for inference.
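The net effect, sketched roughly below: once allow_patterns has been narrowed to a single pattern, snapshot_download only fetches that one weight format, and the glob loop in the last hunk then finds exactly one kind of weight file. This is a simplified illustration, not the exact call in vllm/model_executor/weight_utils.py, which passes additional arguments (such as cache_dir and revision); the repository id and pattern below are placeholders.

import glob
import os

from huggingface_hub import snapshot_download

# Assume the selection step above picked "*.bin" for this example repository.
allow_patterns = ["*.bin"]

# Download only the files matching the selected pattern.
hf_folder = snapshot_download("facebook/opt-125m",
                              allow_patterns=allow_patterns)

# Collect the downloaded weight files, as the loop in the final hunk does.
hf_weights_files = []
for pattern in allow_patterns:
    hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
    if len(hf_weights_files) > 0:
        break

print(hf_weights_files)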
