Pull request overview
Adds LiteLLM-backed remote embedding providers to the embeddings query module while preserving the existing local sentence-transformers execution path, plus CI-safe tests for remote failure behavior.
Changes:
- Route embedding computation to LiteLLM remote providers based on `model_name` prefixes; keep bare/HF-style names local (see the sketch after this list).
- Add remote-only configuration options (e.g., `api_base`, `timeout`, `num_retries`, `remote_batch_size`, `concurrency`, `normalize`) and remote model dimension probing with caching.
- Add e2e tests for graceful remote failures and opt-in pytest e2e tests for real OpenAI/Ollama providers.
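The first bullet describes the routing rule; a minimal sketch of what prefix-based routing can look like follows. The prefix tuple and the `_is_remote` name are illustrative assumptions, not the PR's actual code.

```python
# Minimal sketch of prefix-based routing (illustrative names, not the PR's code).
# LiteLLM addresses remote models as "<provider>/<model>",
# e.g. "openai/text-embedding-3-small" or "ollama/nomic-embed-text".
REMOTE_PREFIXES = (
    "openai/", "azure/", "ollama/", "cohere/", "voyage/",
    "mistral/", "bedrock/", "vertex_ai/", "huggingface/",
)

def _is_remote(model_name: str) -> bool:
    # Bare names ("all-MiniLM-L6-v2") and HF-style org/model names
    # ("BAAI/bge-small-en") carry no known provider prefix and stay local.
    return model_name.startswith(REMOTE_PREFIXES)
```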
Reviewed changes
Copilot reviewed 7 out of 8 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| release/package/mgbuild.sh | Ensures build deps are installed in the container before building the gssapi wheel. |
| mage/python/embeddings.py | Implements LiteLLM remote routing, batching/concurrency, client-side normalization, nullable dimension, and remote model-info probing/cache. |
| mage/tests/e2e/embeddings_test/test_remote_text_failure/test.yml | CI-safe e2e test verifying remote embeddings.text failure returns success=false with null outputs. |
| mage/tests/e2e/embeddings_test/test_remote_node_sentence_failure/test.yml | CI-safe e2e test verifying remote embeddings.node_sentence failure doesn’t write embeddings. |
| mage/tests/e2e/embeddings_test/test_remote_node_sentence_failure/input.cyp | Fixture graph data for the remote node_sentence failure test. |
| mage/tests/e2e/embeddings_test/test_remote_openai.py | Opt-in real-provider e2e tests for OpenAI embeddings/model_info/node_sentence. |
| mage/tests/e2e/embeddings_test/test_remote_ollama.py | Opt-in real-provider e2e tests for Ollama embeddings/model_info/node_sentence. |
```diff
@@ -591,6 +780,44 @@
     return info
 
 
+def _remote_model_info(configuration, resolved):
+    """Probe a remote provider once to learn its embedding dimension.
+
+    Cached per (model_name, api_base) so subsequent calls don't re-hit the API.
+    """
+    _model, _provider, default_api_base = resolved
+    api_base = configuration.get("api_base") or default_api_base
+    cache_key = (configuration["model_name"], api_base)
```
_remote_info_cache is an unbounded global dict keyed by (model_name, api_base). In long-running Memgraph processes with many unique model/api_base values, this can grow without limit.
Consider adding a simple bound/eviction policy (e.g., LRU with a max size) or a TTL to avoid unbounded memory growth.
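A minimal sketch of a bounded LRU along those lines, backed by an `OrderedDict`; the cap constant and helper name are illustrative, not part of the module:

```python
from collections import OrderedDict

_REMOTE_INFO_CACHE_MAX = 128        # illustrative cap, tune as needed
_remote_info_cache = OrderedDict()  # (model_name, api_base) -> probed model info

def _cache_remote_info(cache_key, info):
    # Refresh recency on re-insert, then evict the least-recently-used
    # entry once the cap is exceeded.
    if cache_key in _remote_info_cache:
        _remote_info_cache.move_to_end(cache_key)
    _remote_info_cache[cache_key] = info
    if len(_remote_info_cache) > _REMOTE_INFO_CACHE_MAX:
        _remote_info_cache.popitem(last=False)
```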
```python
pytestmark = pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY"),
    reason="OPENAI_API_KEY not set — skipping opt-in OpenAI e2e test",
)
```
This opt-in test runs whenever OPENAI_API_KEY is present in the pytest environment. In some developer/CI setups that key may be set for unrelated reasons, which could accidentally trigger real external API calls.
Consider gating with a dedicated opt-in env var (similar to MAGE_E2E_OLLAMA=1) in addition to OPENAI_API_KEY, so it only runs when explicitly requested.
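A sketch of such a double gate; `MAGE_E2E_OPENAI` is a hypothetical flag name mirroring the `MAGE_E2E_OLLAMA=1` convention mentioned above:

```python
# Hypothetical double gate: both the explicit opt-in flag and the API key
# must be present before any real external call is made.
pytestmark = pytest.mark.skipif(
    os.environ.get("MAGE_E2E_OPENAI") != "1" or not os.environ.get("OPENAI_API_KEY"),
    reason="Set MAGE_E2E_OPENAI=1 and OPENAI_API_KEY to run the opt-in OpenAI e2e tests",
)
```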
```python
with ThreadPoolExecutor(max_workers=max(1, cfg["concurrency"])) as ex:
    fut2idx = {ex.submit(_call, c): i for i, c in enumerate(chunks)}
    for fut in as_completed(fut2idx):
        results[fut2idx[fut]] = fut.result()
```
If any chunk future raises, fut.result() will propagate immediately, but the with ThreadPoolExecutor(...) context will still wait for all other in-flight futures to finish on exit. This can significantly increase end-to-end latency on partial failure (e.g., one chunk timing out) even though the call is already considered failed.
Consider catching exceptions per-future, cancelling remaining futures (when possible), and re-raising after shutdown with cancel_futures=True (or equivalent) so failures return promptly.
Suggested change:
```diff
-with ThreadPoolExecutor(max_workers=max(1, cfg["concurrency"])) as ex:
-    fut2idx = {ex.submit(_call, c): i for i, c in enumerate(chunks)}
-    for fut in as_completed(fut2idx):
-        results[fut2idx[fut]] = fut.result()
+ex = ThreadPoolExecutor(max_workers=max(1, cfg["concurrency"]))
+fut2idx = {}
+failed = False
+try:
+    fut2idx = {ex.submit(_call, c): i for i, c in enumerate(chunks)}
+    for fut in as_completed(fut2idx):
+        try:
+            results[fut2idx[fut]] = fut.result()
+        except Exception:
+            failed = True
+            for other in fut2idx:
+                if other is not fut:
+                    other.cancel()
+            raise
+finally:
+    if failed:
+        ex.shutdown(wait=False, cancel_futures=True)
+    else:
+        ex.shutdown(wait=True)
```
```python
except Exception as e:
    logger.error(f"Failed to compute embeddings: {e}")
    return return_data(
        input_items,
        configuration["embedding_property"],
        configuration["return_embeddings"],
        False,
        dimension=dimension,
    )
```
compute_embeddings() can raise UnboundLocalError in the outer except block because dimension is only assigned in the local-path branch. If an exception occurs before that assignment (e.g., during remote routing or config access), the error handler references an undefined dimension.
Initialize dimension = None before the try, or set dimension in all branches before any possible exception handler uses it.
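A minimal sketch of the suggested fix, with the body of compute_embeddings() abbreviated:

```python
def compute_embeddings(input_items, configuration):
    # Defined up front so the except block can always reference it, even when
    # the failure happens before the local path assigns a real value.
    dimension = None
    try:
        ...  # remote routing / config access / local path, which may set dimension
    except Exception as e:
        logger.error(f"Failed to compute embeddings: {e}")
        return return_data(
            input_items,
            configuration["embedding_property"],
            configuration["return_embeddings"],
            False,
            dimension=dimension,
        )
```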
```python
        dimension=dimension,
    )

chunk_size = cfg["remote_batch_size"] or default_remote_batch_size(provider)
```
remote_batch_size is used as the range() step when chunking (range(0, n, chunk_size)). If a user passes remote_batch_size <= 0 (or a non-int), this will either raise (step=0) or silently produce no chunks (negative step), returning success=True without making any remote call.
Validate and coerce remote_batch_size (and possibly timeout/num_retries) in validate_configuration() to ensure it’s an integer >= 1 when set.
Suggested change:
```diff
-chunk_size = cfg["remote_batch_size"] or default_remote_batch_size(provider)
+configured_chunk_size = cfg["remote_batch_size"]
+if configured_chunk_size is None:
+    chunk_size = default_remote_batch_size(provider)
+else:
+    try:
+        chunk_size = int(configured_chunk_size)
+    except (TypeError, ValueError):
+        raise ValueError("remote_batch_size must be an integer >= 1 when set.")
+    if chunk_size < 1:
+        raise ValueError("remote_batch_size must be an integer >= 1 when set.")
```
```python
results = [None] * len(chunks)
with ThreadPoolExecutor(max_workers=max(1, cfg["concurrency"])) as ex:
    fut2idx = {ex.submit(_call, c): i for i, c in enumerate(chunks)}
```
ThreadPoolExecutor(max_workers=max(1, cfg["concurrency"])) will raise a TypeError if concurrency is provided as a non-int (e.g., string/float from Cypher). It also allows extreme values that could create too many threads inside the query module.
Consider validating/coercing concurrency in validate_configuration() to an int within a safe range (e.g., >=1 and <= some cap), and fail fast with a clear error/log message when invalid.
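One possible shape for that validation, as a sketch that could live in validate_configuration(); the cap value and helper name are assumptions:

```python
_MAX_CONCURRENCY = 32  # illustrative safety cap on worker threads

def _validated_concurrency(cfg):
    # Coerce early and fail fast with a clear message rather than letting
    # ThreadPoolExecutor raise a TypeError deep inside the query module.
    raw = cfg.get("concurrency", 1)
    try:
        concurrency = int(raw)
    except (TypeError, ValueError):
        raise ValueError(f"concurrency must be an integer >= 1, got {raw!r}")
    if not 1 <= concurrency <= _MAX_CONCURRENCY:
        raise ValueError(
            f"concurrency must be between 1 and {_MAX_CONCURRENCY}, got {concurrency}"
        )
    return concurrency
```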
```python
return [d["embedding"] for d in resp["data"]]
```
The remote embedding response is assumed to have 1:1 correspondence with the input chunk and to be in the same order ([d["embedding"] for d in resp["data"]]). Some providers return an index field or may not guarantee ordering; partial responses would also silently misalign embeddings when writing back to vertices.
Add a sanity check that len(resp["data"]) == len(chunk_texts) and, when available, sort by d["index"] (or equivalent) to preserve correct alignment before flattening/writing results.
Suggested change:
```diff
-return [d["embedding"] for d in resp["data"]]
+data = resp["data"]
+if len(data) != len(chunk_texts):
+    raise ValueError(
+        f"Embedding provider returned {len(data)} embeddings for {len(chunk_texts)} inputs."
+    )
+def _item_index(item):
+    if isinstance(item, dict):
+        return item.get("index")
+    return getattr(item, "index", None)
+if data and all(_item_index(item) is not None for item in data):
+    data = sorted(data, key=_item_index)
+return [d["embedding"] for d in data]
```
```python
api_base = cfg.get("api_base")
if not isinstance(api_base, str) or "@" not in api_base:
    return cfg
return {**cfg, "api_base": re.sub(r"(://)[^/@\s]+:[^/@\s]+@", r"\1<redacted>@", api_base)}
```
_redacted_config() only redacts credentials in the user:pass@host URL form. Secrets can still leak into logs if api_base contains tokens in the query string or path (e.g., ...?token=...), and this function currently returns the original config unchanged in those cases.
Consider parsing api_base with urllib.parse and redacting userinfo plus query/fragment (or whitelisting to scheme/host/port only) before logging.
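A sketch of scheme/host-only redaction with urllib.parse; the helper name and the whitelist policy are illustrative assumptions, not the module's current _redacted_config():

```python
from urllib.parse import urlsplit

def _redact_api_base(api_base: str) -> str:
    # Keep only scheme://host[:port]; drop userinfo, path, query, and fragment
    # so embedded secrets (e.g. "...?token=...") can't reach the logs.
    parts = urlsplit(api_base)
    if not parts.scheme or not parts.netloc:
        return "<redacted>"
    host = parts.netloc.rsplit("@", 1)[-1]  # strip any user:pass@ userinfo
    return f"{parts.scheme}://{host}"
```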


Add remote embedding providers to the embeddings query module
Adds support for any embedding provider supported by LiteLLM (OpenAI, Azure OpenAI, Ollama, Cohere, Voyage, Mistral, Jina, Bedrock, Vertex AI, Hugging Face, and any OpenAI-compatible endpoint) while keeping the local sentence-transformers path unchanged.
Summary
Backwards compatibility
Zero Cypher changes required. `CALL embeddings.node_sentence(null, {})` and `CALL embeddings.text([...], {})` keep the local path with the same default model and return shape.
Tests