-
Notifications
You must be signed in to change notification settings - Fork 222
feat: Remote text embeddings compute #4064
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
3743384
6c1127c
3061000
325e49f
3535c87
96ba58b
98a7fc6
798c9fd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,12 +2,17 @@ | |||||||||||||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||||||||||||||
| import sys | ||||||||||||||||||||||||||||||||||||||||||||||||
| import threading | ||||||||||||||||||||||||||||||||||||||||||||||||
| from concurrent.futures import ProcessPoolExecutor, as_completed | ||||||||||||||||||||||||||||||||||||||||||||||||
| from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed | ||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import List | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| import huggingface_hub # noqa: F401 | ||||||||||||||||||||||||||||||||||||||||||||||||
| import litellm | ||||||||||||||||||||||||||||||||||||||||||||||||
| import mgp | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # Suppress LiteLLM's "Provider List: ..." banner that fires on every | ||||||||||||||||||||||||||||||||||||||||||||||||
| # get_llm_provider() miss — we probe deliberately for local names. | ||||||||||||||||||||||||||||||||||||||||||||||||
| litellm.suppress_debug_info = True | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # We need to import huggingface_hub, otherwise sentence_transformers will fail to load the model. | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| sys.path.append(os.path.join(os.path.dirname(__file__), "embed_worker")) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -200,7 +205,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||
| batch_size: int = 2000, | ||||||||||||||||||||||||||||||||||||||||||||||||
| return_embeddings: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||
| dimension: int = None, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=int): | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=mgp.Nullable[int]): | ||||||||||||||||||||||||||||||||||||||||||||||||
| model = _get_or_load_model(model_name, "cpu") | ||||||||||||||||||||||||||||||||||||||||||||||||
| vertex_input = isinstance(embedding_property, str) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if vertex_input: | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -242,7 +247,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||
| device: int = 0, | ||||||||||||||||||||||||||||||||||||||||||||||||
| return_embeddings: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||
| dimension: int = None, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=int): | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=mgp.Nullable[int]): | ||||||||||||||||||||||||||||||||||||||||||||||||
| vertex_input = isinstance(embedding_property, str) | ||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||
| model = _get_or_load_model(model_name, f"cuda:{device}") | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -305,7 +310,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||
| gpus: List[int] = [0], | ||||||||||||||||||||||||||||||||||||||||||||||||
| return_embeddings: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||
| dimension: int = None, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=int): | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=mgp.Nullable[int]): | ||||||||||||||||||||||||||||||||||||||||||||||||
| vertex_input = isinstance(embedding_property, str) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -407,6 +412,116 @@ | |||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def resolve_remote_model(model_name): | ||||||||||||||||||||||||||||||||||||||||||||||||
| """Ask LiteLLM whether ``model_name`` belongs to one of its providers. | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| Returns ``(model, provider, default_api_base)`` if recognized (e.g. | ||||||||||||||||||||||||||||||||||||||||||||||||
| ``"openai/text-embedding-3-small"``), otherwise ``None`` — in which case | ||||||||||||||||||||||||||||||||||||||||||||||||
| the caller should fall through to the local SentenceTransformer path. | ||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||
| if not isinstance(model_name, str) or not model_name: | ||||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||
| model, provider, _dynamic_key, default_api_base = litellm.get_llm_provider(model_name) | ||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||
| return model, provider, default_api_base | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def default_remote_batch_size(provider): | ||||||||||||||||||||||||||||||||||||||||||||||||
| """Default `remote_batch_size` when the user doesn't set one. | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| Only lists providers with a *documented hard cap* — exceeding those | ||||||||||||||||||||||||||||||||||||||||||||||||
| returns HTTP 400. Everything else falls through to a generic default. | ||||||||||||||||||||||||||||||||||||||||||||||||
| Users override per-call via `remote_batch_size`. | ||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||||||||||||||||||||||||||
| "voyage": 1000, # docs.voyageai.com/docs/embeddings | ||||||||||||||||||||||||||||||||||||||||||||||||
| "cohere": 96, # docs.cohere.com/reference/embed | ||||||||||||||||||||||||||||||||||||||||||||||||
| "openai": 2048, # 2048 items + 300K tokens/request; token limit may bite first | ||||||||||||||||||||||||||||||||||||||||||||||||
| "azure": 2048, # mirrors OpenAI | ||||||||||||||||||||||||||||||||||||||||||||||||
| }.get(provider, 256) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def l2_normalize(vec): | ||||||||||||||||||||||||||||||||||||||||||||||||
| s = sum(x * x for x in vec) ** 0.5 | ||||||||||||||||||||||||||||||||||||||||||||||||
| if s == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||
| return vec | ||||||||||||||||||||||||||||||||||||||||||||||||
| return [x / s for x in vec] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def remote_compute( | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Check failure on line 453 in mage/python/embeddings.py
|
||||||||||||||||||||||||||||||||||||||||||||||||
| input_items: mgp.Any, | ||||||||||||||||||||||||||||||||||||||||||||||||
| cfg: mgp.Map, | ||||||||||||||||||||||||||||||||||||||||||||||||
| dimension: int, | ||||||||||||||||||||||||||||||||||||||||||||||||
| resolved, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=mgp.Nullable[int]): | ||||||||||||||||||||||||||||||||||||||||||||||||
| """LiteLLM-routed remote embedding path. | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| Fans chunks across a small thread pool, preserves input order, and | ||||||||||||||||||||||||||||||||||||||||||||||||
| L2-normalizes client-side when ``cfg["normalize"]`` (default) so behavior | ||||||||||||||||||||||||||||||||||||||||||||||||
| matches the local sentence_transformers path. Raises on permanent failure | ||||||||||||||||||||||||||||||||||||||||||||||||
| — the caller catches and converts to ``success=False``. | ||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| _model, provider, default_api_base = resolved | ||||||||||||||||||||||||||||||||||||||||||||||||
| vertex_input = isinstance(cfg["embedding_property"], str) | ||||||||||||||||||||||||||||||||||||||||||||||||
| texts = build_texts(input_items, cfg["excluded_properties"]) if vertex_input else list(input_items) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| n = len(texts) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if n == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||
| return return_data( | ||||||||||||||||||||||||||||||||||||||||||||||||
| input_items if vertex_input else [], | ||||||||||||||||||||||||||||||||||||||||||||||||
| embedding_property_name=cfg["embedding_property"] if vertex_input else None, | ||||||||||||||||||||||||||||||||||||||||||||||||
| return_embeddings=cfg["return_embeddings"], | ||||||||||||||||||||||||||||||||||||||||||||||||
| success=True, | ||||||||||||||||||||||||||||||||||||||||||||||||
| dimension=dimension, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| chunk_size = cfg["remote_batch_size"] or default_remote_batch_size(provider) | ||||||||||||||||||||||||||||||||||||||||||||||||
| chunks = [texts[i : i + chunk_size] for i in range(0, n, chunk_size)] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| api_base = cfg["api_base"] or default_api_base | ||||||||||||||||||||||||||||||||||||||||||||||||
| base_kwargs = { | ||||||||||||||||||||||||||||||||||||||||||||||||
| "model": cfg["model_name"], | ||||||||||||||||||||||||||||||||||||||||||||||||
| "num_retries": cfg["num_retries"], | ||||||||||||||||||||||||||||||||||||||||||||||||
| "timeout": cfg["timeout"], | ||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||
| if api_base: | ||||||||||||||||||||||||||||||||||||||||||||||||
| base_kwargs["api_base"] = api_base | ||||||||||||||||||||||||||||||||||||||||||||||||
| if cfg["input_type"]: | ||||||||||||||||||||||||||||||||||||||||||||||||
| base_kwargs["input_type"] = cfg["input_type"] | ||||||||||||||||||||||||||||||||||||||||||||||||
| if cfg["dimensions"]: | ||||||||||||||||||||||||||||||||||||||||||||||||
| base_kwargs["dimensions"] = cfg["dimensions"] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def _call(chunk_texts): | ||||||||||||||||||||||||||||||||||||||||||||||||
| resp = litellm.embedding(input=chunk_texts, **base_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return [d["embedding"] for d in resp["data"]] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+499
to
+500
|
||||||||||||||||||||||||||||||||||||||||||||||||
| return [d["embedding"] for d in resp["data"]] | |
| data = resp["data"] | |
| if len(data) != len(chunk_texts): | |
| raise ValueError( | |
| f"Embedding provider returned {len(data)} embeddings for {len(chunk_texts)} inputs." | |
| ) | |
| def _item_index(item): | |
| if isinstance(item, dict): | |
| return item.get("index") | |
| return getattr(item, "index", None) | |
| if data and all(_item_index(item) is not None for item in data): | |
| data = sorted(data, key=_item_index) | |
| return [d["embedding"] for d in data] |
Copilot
AI
Apr 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ThreadPoolExecutor(max_workers=max(1, cfg["concurrency"])) will raise a TypeError if concurrency is provided as a non-int (e.g., string/float from Cypher). It also allows extreme values that could create too many threads inside the query module.
Consider validating/coercing concurrency in validate_configuration() to an int within a safe range (e.g., >=1 and <= some cap), and fail fast with a clear error/log message when invalid.
Copilot
AI
Apr 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If any chunk future raises, fut.result() will propagate immediately, but the with ThreadPoolExecutor(...) context will still wait for all other in-flight futures to finish on exit. This can significantly increase end-to-end latency on partial failure (e.g., one chunk timing out) even though the call is already considered failed.
Consider catching exceptions per-future, cancelling remaining futures (when possible), and re-raising after shutdown with cancel_futures=True (or equivalent) so failures return promptly.
| with ThreadPoolExecutor(max_workers=max(1, cfg["concurrency"])) as ex: | |
| fut2idx = {ex.submit(_call, c): i for i, c in enumerate(chunks)} | |
| for fut in as_completed(fut2idx): | |
| results[fut2idx[fut]] = fut.result() | |
| ex = ThreadPoolExecutor(max_workers=max(1, cfg["concurrency"])) | |
| fut2idx = {} | |
| failed = False | |
| try: | |
| fut2idx = {ex.submit(_call, c): i for i, c in enumerate(chunks)} | |
| for fut in as_completed(fut2idx): | |
| try: | |
| results[fut2idx[fut]] = fut.result() | |
| except Exception: | |
| failed = True | |
| for other in fut2idx: | |
| if other is not fut: | |
| other.cancel() | |
| raise | |
| finally: | |
| if failed: | |
| ex.shutdown(wait=False, cancel_futures=True) | |
| else: | |
| ex.shutdown(wait=True) |
Copilot
AI
Apr 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_redacted_config() only redacts credentials in the user:pass@host URL form. Secrets can still leak into logs if api_base contains tokens in the query string or path (e.g., ...?token=...), and this function currently returns the original config unchanged in those cases.
Consider parsing api_base with urllib.parse and redacting userinfo plus query/fragment (or whitelisting to scheme/host/port only) before logging.
Check failure on line 621 in mage/python/embeddings.py
SonarQubeCloud / SonarCloud Code Analysis
Refactor this function to reduce its Cognitive Complexity from 19 to the 15 allowed.
See more on https://sonarcloud.io/project/issues?id=memgraph_memgraph&issues=AZ22DjuAseNgnG04vSIZ&open=AZ22DjuAseNgnG04vSIZ&pullRequest=4064
Check failure on line 641 in mage/python/embeddings.py
SonarQubeCloud / SonarCloud Code Analysis
Change this argument; Function "return_data" expects a different type
See more on https://sonarcloud.io/project/issues?id=memgraph_memgraph&issues=AZ22DjuAseNgnG04vSIa&open=AZ22DjuAseNgnG04vSIa&pullRequest=4064
Check failure on line 646 in mage/python/embeddings.py
SonarQubeCloud / SonarCloud Code Analysis
Change this argument; Function "remote_compute" expects a different type
See more on https://sonarcloud.io/project/issues?id=memgraph_memgraph&issues=AZ22DjuAseNgnG04vSIb&open=AZ22DjuAseNgnG04vSIb&pullRequest=4064
Check failure on line 654 in mage/python/embeddings.py
SonarQubeCloud / SonarCloud Code Analysis
Change this argument; Function "return_data" expects a different type
See more on https://sonarcloud.io/project/issues?id=memgraph_memgraph&issues=AZ22DjuAseNgnG04vSIc&open=AZ22DjuAseNgnG04vSIc&pullRequest=4064
Copilot
AI
Apr 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_remote_info_cache is an unbounded global dict keyed by (model_name, api_base). In long-running Memgraph processes with many unique model/api_base values, this can grow without limit.
Consider adding a simple bound/eviction policy (e.g., LRU with a max size) or a TTL to avoid unbounded memory growth.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| CREATE (:Doc {title: "remote routing failure fixture"}); |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| query: > | ||
| MATCH (n:Doc) | ||
| WITH n, { | ||
| model_name: "ollama/nomic-embed-text", | ||
| api_base: "http://127.0.0.1:1", | ||
| num_retries: 0, | ||
| timeout: 2 | ||
| } AS configuration | ||
| CALL embeddings.node_sentence([n], configuration) | ||
| YIELD success | ||
| RETURN success, n.embedding IS NULL AS no_embedding_written; | ||
|
|
||
| output: | ||
| - success: false | ||
| no_embedding_written: true |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remote_batch_sizeis used as therange()step when chunking (range(0, n, chunk_size)). If a user passesremote_batch_size <= 0(or a non-int), this will either raise (step=0) or silently produce no chunks (negative step), returningsuccess=Truewithout making any remote call.Validate and coerce
remote_batch_size(and possiblytimeout/num_retries) invalidate_configuration()to ensure it’s an integer >= 1 when set.