Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mostlyai/qa/_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,14 +291,16 @@ def calculate_embeddings(
progress_from: int | None = None,
progress_to: int | None = None,
) -> np.ndarray:
t0 = time.time()
# load embedder
t0 = time.time()
embedder = load_embedder()
_LOG.info(f"loaded load_embedder in {time.time() - t0:.2f}s")
# split into buckets for calculating embeddings to avoid memory issues and report continuous progress
steps = progress_to - progress_from if progress_to is not None and progress_from is not None else 1
buckets = np.array_split(strings, steps)
buckets = [b for b in buckets if len(b) > 0]
# calculate embeddings for each bucket
t0 = time.time()
embeds = []
for i, bucket in enumerate(buckets, 1):
embeds += [embedder.encode(bucket.tolist(), show_progress_bar=False)]
Expand Down
4 changes: 2 additions & 2 deletions mostlyai/qa/assets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ def load_tokenizer():

def load_embedder():
"""
Load the embedder model.
Load the embedder model.
Can deal with read-only cache folder by attempting to download the model if it is not locally available.
Users can set MOSTLY_HF_HOME environment variable to override the default cache folder.
"""
from sentence_transformers import SentenceTransformer

model_name = "sentence-transformers/all-MiniLM-L6-v2"
cache_folder=os.getenv("MOSTLY_HF_HOME")
cache_folder = os.getenv("MOSTLY_HF_HOME")
try:
# First try loading from local cache
return SentenceTransformer(model_name_or_path=model_name, cache_folder=cache_folder, local_files_only=True)
Expand Down