feat: Remote text embeddings compute #4064

Open
mattkjames7 wants to merge 8 commits into master from feat/remote-embeddings

Conversation

@mattkjames7 (Contributor) commented Apr 22, 2026

Add remote embedding providers to the embeddings query module

Adds support for any embedding provider supported by LiteLLM — OpenAI, Azure OpenAI, Ollama, Cohere, Voyage, Mistral, Jina, Bedrock, Vertex AI, Hugging Face, and any OpenAI‑compatible endpoint — while keeping the local sentence-transformers path unchanged.

Summary

  • Routing decided by model_name: a LiteLLM provider prefix (e.g. openai/text-embedding-3-small, ollama/nomic-embed-text) routes remote; bare names and HF paths (e.g. BAAI/bge-small-en-v1.5) stay local.
  • New config keys (only used on the remote path): api_base, input_type, dimensions, timeout, num_retries, normalize, remote_batch_size, concurrency.
  • Credentials read from the Memgraph process environment only (matches llm.py — OPENAI_API_KEY, COHERE_API_KEY, etc.); no api_key config key exists, so secrets cannot leak into Cypher or query logs.
  • Remote calls chunked and fanned out with a per-call ThreadPoolExecutor (default concurrency=4), L2-normalized client-side by default to match local behavior.
  • dimension return field is now Nullable[int] — procs surface success=false, dimension=null on failure rather than throwing.
  • embeddings.model_info probes remote providers once for their dimension, cached per (model_name, api_base).
  • Zero new dependencies — litellm is already pinned in mage/python/requirements*.txt (used by llm.py).
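The routing rule in the first bullet can be sketched as a simple prefix check (a hypothetical helper; `KNOWN_PROVIDERS` here is illustrative, not the full set of LiteLLM provider prefixes):

```python
# Hypothetical sketch of the model_name routing rule described above.
# KNOWN_PROVIDERS is illustrative, not the full LiteLLM provider list.
KNOWN_PROVIDERS = {"openai", "azure", "ollama", "cohere", "voyage",
                   "mistral", "jina", "bedrock", "vertex_ai", "huggingface"}

def is_remote(model_name: str) -> bool:
    """A LiteLLM provider prefix routes remote; bare names and Hugging Face
    org/model paths (e.g. BAAI/bge-small-en-v1.5) stay local."""
    prefix, sep, _rest = model_name.partition("/")
    return bool(sep) and prefix in KNOWN_PROVIDERS
```

With this predicate, `openai/text-embedding-3-small` routes remote while `BAAI/bge-small-en-v1.5` and bare names like `all-MiniLM-L6-v2` keep the local sentence-transformers path.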

Backwards compatibility

Zero Cypher changes required. CALL embeddings.node_sentence(null, {}) and CALL embeddings.text([...], {}) keep the local path with the same default model and return shape.

Tests

  • Two new YAML e2e tests that exercise the remote route without network egress (api_base: "http://127.0.0.1:1" → ECONNREFUSED → graceful success=false). These run in CI.
  • Two opt-in pytest files for real providers, skipped unless OPENAI_API_KEY / MAGE_E2E_OLLAMA=1 are set. Not run in CI; intended for local verification.

@sonarqubecloud

Quality Gate failed

Failed conditions
1 Security Hotspot

See analysis details on SonarQube Cloud

@mattkjames7 (Contributor, Author) commented Apr 23, 2026

Tracking

  • [Link to Epic/Issue]

Standard development

CI Testing Labels

  • Select the appropriate CI test labels (CI -build=build-name -test=test-suite)

Documentation checklist

  • Add the documentation label
  • Add the bug / feature label
  • Add the milestone for which this feature is intended
    • If not known, set for a later milestone
  • Write a release note, including added/changed clauses
    • Added the ability to request text embeddings from remote APIs using litellm. Users should export the relevant key(s) for the API(s) they intend to request embeddings from when launching the memgraph/memgraph-mage container, e.g. -e OPENAI_API_KEY=$OPENAI_API_KEY, then set model_name in the format "{provider}/{model-name}" (see the LiteLLM documentation for valid providers and their models). #4064
  • feat: Add remote embedding computation to the embeddings module documentation #1598
    • Is back linked to this development PR

@mattkjames7 mattkjames7 marked this pull request as ready for review April 23, 2026 15:07
Copilot AI left a comment

Pull request overview

Adds LiteLLM-backed remote embedding providers to the embeddings query module while preserving the existing local sentence-transformers execution path, plus CI-safe tests for remote failure behavior.

Changes:

  • Route embedding computation to LiteLLM remote providers based on model_name prefixes; keep bare/HF-style names local.
  • Add remote-only configuration options (e.g., api_base, timeout, num_retries, remote_batch_size, concurrency, normalize) and remote model dimension probing with caching.
  • Add e2e tests for graceful remote failures and opt-in pytest e2e tests for real OpenAI/Ollama providers.

Reviewed changes

Copilot reviewed 7 out of 8 changed files in this pull request and generated 8 comments.

| File | Description |
| --- | --- |
| release/package/mgbuild.sh | Ensures build deps are installed in the container before building the gssapi wheel. |
| mage/python/embeddings.py | Implements LiteLLM remote routing, batching/concurrency, client-side normalization, nullable dimension, and remote model-info probing/cache. |
| mage/tests/e2e/embeddings_test/test_remote_text_failure/test.yml | CI-safe e2e test verifying remote embeddings.text failure returns success=false with null outputs. |
| mage/tests/e2e/embeddings_test/test_remote_node_sentence_failure/test.yml | CI-safe e2e test verifying remote embeddings.node_sentence failure doesn't write embeddings. |
| mage/tests/e2e/embeddings_test/test_remote_node_sentence_failure/input.cyp | Fixture graph data for the remote node_sentence failure test. |
| mage/tests/e2e/embeddings_test/test_remote_openai.py | Opt-in real-provider e2e tests for OpenAI embeddings/model_info/node_sentence. |
| mage/tests/e2e/embeddings_test/test_remote_ollama.py | Opt-in real-provider e2e tests for Ollama embeddings/model_info/node_sentence. |


Comment thread mage/python/embeddings.py
Comment on lines 763 to +790
@@ -591,6 +780,44 @@
```python
    return info


def _remote_model_info(configuration, resolved):
    """Probe a remote provider once to learn its embedding dimension.

    Cached per (model_name, api_base) so subsequent calls don't re-hit the API.
    """
    _model, _provider, default_api_base = resolved
    api_base = configuration.get("api_base") or default_api_base
    cache_key = (configuration["model_name"], api_base)
```
Copilot AI Apr 23, 2026

_remote_info_cache is an unbounded global dict keyed by (model_name, api_base). In long-running Memgraph processes with many unique model/api_base values, this can grow without limit.

Consider adding a simple bound/eviction policy (e.g., LRU with a max size) or a TTL to avoid unbounded memory growth.

Comment on lines +25 to +28
```python
pytestmark = pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY"),
    reason="OPENAI_API_KEY not set — skipping opt-in OpenAI e2e test",
)
```
Copilot AI Apr 23, 2026

This opt-in test runs whenever OPENAI_API_KEY is present in the pytest environment. In some developer/CI setups that key may be set for unrelated reasons, which could accidentally trigger real external API calls.

Consider gating with a dedicated opt-in env var (similar to MAGE_E2E_OLLAMA=1) in addition to OPENAI_API_KEY, so it only runs when explicitly requested.
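The suggested gating could be backed by a predicate like this (MAGE_E2E_OPENAI is a hypothetical flag name, mirroring the existing MAGE_E2E_OLLAMA convention):

```python
import os

def opted_in(env=None) -> bool:
    """Run the real-provider test only when explicitly requested.

    MAGE_E2E_OPENAI is a hypothetical flag name, mirroring MAGE_E2E_OLLAMA;
    both the opt-in flag and the credential must be present.
    """
    env = os.environ if env is None else env
    return env.get("MAGE_E2E_OPENAI") == "1" and bool(env.get("OPENAI_API_KEY"))

# In the test module this would back the skip marker, e.g.:
# pytestmark = pytest.mark.skipif(not opted_in(), reason="opt-in e2e test")
```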

Comment thread mage/python/embeddings.py
Comment on lines +502 to +505
```python
with ThreadPoolExecutor(max_workers=max(1, cfg["concurrency"])) as ex:
    fut2idx = {ex.submit(_call, c): i for i, c in enumerate(chunks)}
    for fut in as_completed(fut2idx):
        results[fut2idx[fut]] = fut.result()
```
Copilot AI Apr 23, 2026

If any chunk future raises, fut.result() will propagate immediately, but the with ThreadPoolExecutor(...) context will still wait for all other in-flight futures to finish on exit. This can significantly increase end-to-end latency on partial failure (e.g., one chunk timing out) even though the call is already considered failed.

Consider catching exceptions per-future, cancelling remaining futures (when possible), and re-raising after shutdown with cancel_futures=True (or equivalent) so failures return promptly.

Suggested change

```python
# current
with ThreadPoolExecutor(max_workers=max(1, cfg["concurrency"])) as ex:
    fut2idx = {ex.submit(_call, c): i for i, c in enumerate(chunks)}
    for fut in as_completed(fut2idx):
        results[fut2idx[fut]] = fut.result()
```

```python
# suggested
ex = ThreadPoolExecutor(max_workers=max(1, cfg["concurrency"]))
fut2idx = {}
failed = False
try:
    fut2idx = {ex.submit(_call, c): i for i, c in enumerate(chunks)}
    for fut in as_completed(fut2idx):
        try:
            results[fut2idx[fut]] = fut.result()
        except Exception:
            failed = True
            for other in fut2idx:
                if other is not fut:
                    other.cancel()
            raise
finally:
    if failed:
        ex.shutdown(wait=False, cancel_futures=True)
    else:
        ex.shutdown(wait=True)
```

Comment thread mage/python/embeddings.py
Comment on lines 752 to 760
```python
except Exception as e:
    logger.error(f"Failed to compute embeddings: {e}")
    return return_data(
        input_items,
        configuration["embedding_property"],
        configuration["return_embeddings"],
        False,
        dimension=dimension,
    )
```
Copilot AI Apr 23, 2026

compute_embeddings() can raise UnboundLocalError in the outer except block because dimension is only assigned in the local-path branch. If an exception occurs before that assignment (e.g., during remote routing or config access), the error handler references an undefined dimension.

Initialize dimension = None before the try, or set dimension in all branches before any possible exception handler uses it.
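A minimal reproduction of the hazard, with the suggested fix alongside (simplified stand-ins, not the module's actual code):

```python
# risky() mirrors the hazard: `dimension` is assigned only inside the try,
# so the handler hits UnboundLocalError instead of the intended fallback.
def risky():
    try:
        raise RuntimeError("fails before dimension is assigned")
        dimension = 384  # never reached, but makes `dimension` a local name
    except RuntimeError:
        return dimension  # raises UnboundLocalError

# safe() applies the suggested fix: initialize before the try.
def safe():
    dimension = None
    try:
        raise RuntimeError("fails before dimension is assigned")
    except RuntimeError:
        return dimension  # returns None gracefully
```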

Comment thread mage/python/embeddings.py
```python
        dimension=dimension,
    )

chunk_size = cfg["remote_batch_size"] or default_remote_batch_size(provider)
```
Copilot AI Apr 23, 2026

remote_batch_size is used as the range() step when chunking (range(0, n, chunk_size)). If a user passes remote_batch_size <= 0 (or a non-int), this will either raise (step=0) or silently produce no chunks (negative step), returning success=True without making any remote call.

Validate and coerce remote_batch_size (and possibly timeout/num_retries) in validate_configuration() to ensure it’s an integer >= 1 when set.

Suggested change

```python
# current
chunk_size = cfg["remote_batch_size"] or default_remote_batch_size(provider)
```

```python
# suggested
configured_chunk_size = cfg["remote_batch_size"]
if configured_chunk_size is None:
    chunk_size = default_remote_batch_size(provider)
else:
    try:
        chunk_size = int(configured_chunk_size)
    except (TypeError, ValueError):
        raise ValueError("remote_batch_size must be an integer >= 1 when set.")
    if chunk_size < 1:
        raise ValueError("remote_batch_size must be an integer >= 1 when set.")
```

Comment thread mage/python/embeddings.py
Comment on lines +501 to +503
```python
results = [None] * len(chunks)
with ThreadPoolExecutor(max_workers=max(1, cfg["concurrency"])) as ex:
    fut2idx = {ex.submit(_call, c): i for i, c in enumerate(chunks)}
```
Copilot AI Apr 23, 2026

ThreadPoolExecutor(max_workers=max(1, cfg["concurrency"])) will raise a TypeError if concurrency is provided as a non-int (e.g., string/float from Cypher). It also allows extreme values that could create too many threads inside the query module.

Consider validating/coercing concurrency in validate_configuration() to an int within a safe range (e.g., >=1 and <= some cap), and fail fast with a clear error/log message when invalid.
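A sketch of the suggested validation (the default of 4 matches the PR description; the cap of 32 is an illustrative choice, not a value from the PR):

```python
# Hypothetical coercion for the `concurrency` config option.
MAX_CONCURRENCY = 32  # illustrative cap on threads per call

def coerce_concurrency(value, default: int = 4) -> int:
    """Coerce to an int in [1, MAX_CONCURRENCY]; fail fast on junk input."""
    if value is None:
        return default
    try:
        n = int(value)
    except (TypeError, ValueError):
        raise ValueError("concurrency must be an integer >= 1")
    if n < 1:
        raise ValueError("concurrency must be an integer >= 1")
    return min(n, MAX_CONCURRENCY)  # clamp extreme values to a safe cap
```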

Comment thread mage/python/embeddings.py
Comment on lines +499 to +500
```python
return [d["embedding"] for d in resp["data"]]
```
Copilot AI Apr 23, 2026

The remote embedding response is assumed to have 1:1 correspondence with the input chunk and to be in the same order ([d["embedding"] for d in resp["data"]]). Some providers return an index field or may not guarantee ordering; partial responses would also silently misalign embeddings when writing back to vertices.

Add a sanity check that len(resp["data"]) == len(chunk_texts) and, when available, sort by d["index"] (or equivalent) to preserve correct alignment before flattening/writing results.

Suggested change

```python
# current
return [d["embedding"] for d in resp["data"]]
```

```python
# suggested
data = resp["data"]
if len(data) != len(chunk_texts):
    raise ValueError(
        f"Embedding provider returned {len(data)} embeddings for {len(chunk_texts)} inputs."
    )

def _item_index(item):
    if isinstance(item, dict):
        return item.get("index")
    return getattr(item, "index", None)

if data and all(_item_index(item) is not None for item in data):
    data = sorted(data, key=_item_index)
return [d["embedding"] for d in data]
```

Comment thread mage/python/embeddings.py
Comment on lines +615 to +618
```python
api_base = cfg.get("api_base")
if not isinstance(api_base, str) or "@" not in api_base:
    return cfg
return {**cfg, "api_base": re.sub(r"(://)[^/@\s]+:[^/@\s]+@", r"\1<redacted>@", api_base)}
```
Copilot AI Apr 23, 2026

_redacted_config() only redacts credentials in the user:pass@host URL form. Secrets can still leak into logs if api_base contains tokens in the query string or path (e.g., ...?token=...), and this function currently returns the original config unchanged in those cases.

Consider parsing api_base with urllib.parse and redacting userinfo plus query/fragment (or whitelisting to scheme/host/port only) before logging.
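A sketch of the whitelist approach using urllib.parse (not the PR's actual code): keep only scheme, host, port, and path, and drop userinfo, query string, and fragment before logging.

```python
from urllib.parse import urlsplit, urlunsplit

def redact_api_base(api_base: str) -> str:
    """Whitelist scheme/host/port/path for log output; userinfo is replaced
    with <redacted> and query/fragment are dropped entirely."""
    parts = urlsplit(api_base)
    host = parts.hostname or ""
    if parts.port:
        host = f"{host}:{parts.port}"
    netloc = f"<redacted>@{host}" if parts.username else host
    return urlunsplit((parts.scheme, netloc, parts.path, "", ""))
```

Unlike the regex, this also strips `?token=...`-style secrets because the query component is never re-emitted.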

2 participants