Fix model download for ONNX embedder (#976)
## Description of changes
The current function checks for the tar.gz archive rather than for the extracted folder, so if the archive is deleted after extraction, the model gets downloaded again. This PR fixes that by checking for the model files in the extracted folder before attempting to download or extract again.
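
A hedged sketch of the user-visible effect, assuming the `ONNXMiniLM_L6_V2` class defined in the file this PR touches (the model is fetched lazily on the first embedding call):

```python
from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2

ef = ONNXMiniLM_L6_V2()
vectors = ef(["hello world"])  # first call downloads and extracts the model

# Before this PR: deleting the downloaded .tar.gz forced a re-download on the
# next call, even though the extracted model files were still on disk.
# After this PR: the extracted files are checked first, so a later call
# performs no network I/O.
vectors_again = ef(["hello again"])
```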

## Test plan
Tested by using it.

## Documentation Changes
I didn't find any existing documentation describing how the model download works.
Josh-XT committed Aug 15, 2023
1 parent c26667f commit f1af776
Showing 1 changed file with 32 additions and 10 deletions.
42 changes: 32 additions & 10 deletions chromadb/utils/embedding_functions.py
@@ -271,7 +271,7 @@ def __init__(self) -> None:
 
     # Borrowed from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51
     # Download with tqdm to preserve the sentence-transformers experience
-    def _download(self, url: str, fname: Path, chunk_size: int = 1024) -> None:
+    def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None:
         resp = requests.get(url, stream=True)
         total = int(resp.headers.get("content-length", 0))
         with open(fname, "wb") as file, self.tqdm(
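
The hunk above only retypes `fname` from `Path` to `str`, matching the `os.path.join` calls introduced below. For context, here is a minimal self-contained sketch of the gist-derived streaming download pattern the method follows, assuming only the `requests` and `tqdm` packages (the standalone name `download` is mine, not Chroma's):

```python
import requests
from tqdm import tqdm

def download(url: str, fname: str, chunk_size: int = 1024) -> None:
    """Stream a file to disk with a progress bar (pattern from the linked gist)."""
    resp = requests.get(url, stream=True)
    total = int(resp.headers.get("content-length", 0))
    with open(fname, "wb") as file, tqdm(
        desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024
    ) as bar:
        for data in resp.iter_content(chunk_size=chunk_size):
            bar.update(file.write(data))  # write returns the byte count
```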
@@ -326,14 +326,18 @@ def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray:
     def _init_model_and_tokenizer(self) -> None:
         if self.model is None and self.tokenizer is None:
             self.tokenizer = self.Tokenizer.from_file(
-                str(self.DOWNLOAD_PATH / self.EXTRACTED_FOLDER_NAME / "tokenizer.json")
+                os.path.join(
+                    self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "tokenizer.json"
+                )
             )
             # max_seq_length = 256, for some reason sentence-transformers uses 256 even though the HF config has a max length of 128
             # https://github.com/UKPLab/sentence-transformers/blob/3e1929fddef16df94f8bc6e3b10598a98f46e62d/docs/_static/html/models_en_sentence_embeddings.html#LL480
             self.tokenizer.enable_truncation(max_length=256)
             self.tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=256)
             self.model = self.ort.InferenceSession(
-                str(self.DOWNLOAD_PATH / self.EXTRACTED_FOLDER_NAME / "model.onnx")
+                os.path.join(
+                    self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "model.onnx"
+                )
             )
 
     def __call__(self, texts: Documents) -> Embeddings:
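
The hunk above is a mechanical change: `pathlib`'s `/` operator and `os.path.join` build the same path, and the PR standardizes on the latter so the file works with plain strings throughout. A quick illustrative check (the base directory is a placeholder, not Chroma's actual `DOWNLOAD_PATH`):

```python
import os
from pathlib import Path

base = Path("/tmp/onnx_models")  # placeholder; not Chroma's real cache path
p1 = str(base / "extracted" / "tokenizer.json")
p2 = os.path.join(base, "extracted", "tokenizer.json")
assert p1 == p2  # os.path.join accepts path-like objects and returns a str
```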
@@ -344,16 +348,35 @@ def __call__(self, texts: Documents) -> Embeddings:
         return res
 
     def _download_model_if_not_exists(self) -> None:
+        onnx_files = [
+            "config.json",
+            "model.onnx",
+            "special_tokens_map.json",
+            "tokenizer_config.json",
+            "tokenizer.json",
+            "vocab.txt",
+        ]
+        extracted_folder = os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME)
+        onnx_files_exist = True
+        for f in onnx_files:
+            if not os.path.exists(os.path.join(extracted_folder, f)):
+                onnx_files_exist = False
+                break
         # Model is not downloaded yet
-        if not os.path.exists(self.DOWNLOAD_PATH / self.ARCHIVE_FILENAME):
+        if not onnx_files_exist:
             os.makedirs(self.DOWNLOAD_PATH, exist_ok=True)
-            self._download(
-                self.MODEL_DOWNLOAD_URL, self.DOWNLOAD_PATH / self.ARCHIVE_FILENAME
-            )
+            if not os.path.exists(
+                os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME)
+            ):
+                self._download(
+                    url=self.MODEL_DOWNLOAD_URL,
+                    fname=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
+                )
             with tarfile.open(
-                self.DOWNLOAD_PATH / self.ARCHIVE_FILENAME, "r:gz"
+                name=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
+                mode="r:gz",
             ) as tar:
-                tar.extractall(self.DOWNLOAD_PATH)
+                tar.extractall(path=self.DOWNLOAD_PATH)
 
 
 def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction]:
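
The hunk above is the heart of the fix: look for the extracted model files first, fall back to the cached archive, and download only as a last resort, which makes the operation safe to re-run. A generalized sketch of that pattern under assumed names (`ensure_model` and its parameters are illustrative; `urllib.request.urlretrieve` stands in for the tqdm download helper):

```python
import os
import tarfile
import urllib.request
from typing import List

def ensure_model(url: str, download_dir: str, archive_name: str,
                 extracted_dir_name: str, required_files: List[str]) -> None:
    """Re-entrant download: skip all work when the extracted files are present."""
    extracted = os.path.join(download_dir, extracted_dir_name)
    if all(os.path.exists(os.path.join(extracted, f)) for f in required_files):
        return  # model already extracted; deleting the archive costs nothing
    os.makedirs(download_dir, exist_ok=True)
    archive = os.path.join(download_dir, archive_name)
    if not os.path.exists(archive):
        # stdlib stand-in for the streaming tqdm helper sketched earlier
        urllib.request.urlretrieve(url, archive)
    with tarfile.open(name=archive, mode="r:gz") as tar:
        tar.extractall(path=download_dir)
```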
@@ -410,7 +433,6 @@ def __init__(
         self._session.headers.update({"Authorization": f"Bearer {api_key}"})
 
     def __call__(self, texts: Documents) -> Embeddings:
-
         embeddings = []
         for text in texts:
             response = self._session.post(
