251 changes: 239 additions & 12 deletions mage/python/embeddings.py
@@ -2,12 +2,17 @@
import os
import sys
import threading
from concurrent.futures import ProcessPoolExecutor, as_completed
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from typing import List

import huggingface_hub # noqa: F401
import litellm
import mgp

# Suppress LiteLLM's "Provider List: ..." banner that fires on every
# get_llm_provider() miss — we probe deliberately for local names.
litellm.suppress_debug_info = True

# We need to import huggingface_hub, otherwise sentence_transformers will fail to load the model.

sys.path.append(os.path.join(os.path.dirname(__file__), "embed_worker"))
@@ -200,7 +205,7 @@
batch_size: int = 2000,
return_embeddings: bool = False,
dimension: int = None,
) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=int):
) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=mgp.Nullable[int]):
model = _get_or_load_model(model_name, "cpu")
vertex_input = isinstance(embedding_property, str)
if vertex_input:
@@ -242,7 +247,7 @@
device: int = 0,
return_embeddings: bool = False,
dimension: int = None,
) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=int):
) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=mgp.Nullable[int]):
vertex_input = isinstance(embedding_property, str)
try:
model = _get_or_load_model(model_name, f"cuda:{device}")
@@ -305,7 +310,7 @@
gpus: List[int] = [0],
return_embeddings: bool = False,
dimension: int = None,
) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=int):
) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=mgp.Nullable[int]):
vertex_input = isinstance(embedding_property, str)

try:
@@ -407,6 +412,116 @@
)


def resolve_remote_model(model_name):
"""Ask LiteLLM whether ``model_name`` belongs to one of its providers.

Returns ``(model, provider, default_api_base)`` if recognized (e.g.
``"openai/text-embedding-3-small"``), otherwise ``None`` — in which case
the caller should fall through to the local SentenceTransformer path.
"""
if not isinstance(model_name, str) or not model_name:
return None
try:
model, provider, _dynamic_key, default_api_base = litellm.get_llm_provider(model_name)
except Exception:
return None
return model, provider, default_api_base


def default_remote_batch_size(provider):
"""Default `remote_batch_size` when the user doesn't set one.

Only lists providers with a *documented hard cap* — exceeding those
returns HTTP 400. Everything else falls through to a generic default.
Users override per-call via `remote_batch_size`.
"""
return {
"voyage": 1000, # docs.voyageai.com/docs/embeddings
"cohere": 96, # docs.cohere.com/reference/embed
"openai": 2048, # 2048 items + 300K tokens/request; token limit may bite first
"azure": 2048, # mirrors OpenAI
}.get(provider, 256)


def l2_normalize(vec):
s = sum(x * x for x in vec) ** 0.5
if s == 0:
return vec
return [x / s for x in vec]


def remote_compute(

[SonarCloud Code Analysis check failure on line 453 in mage/python/embeddings.py]
Refactor this function to reduce its Cognitive Complexity from 20 to the 15 allowed.
See more on https://sonarcloud.io/project/issues?id=memgraph_memgraph&issues=AZ22DjuAseNgnG04vSIX&open=AZ22DjuAseNgnG04vSIX&pullRequest=4064
input_items: mgp.Any,
cfg: mgp.Map,
dimension: int,
resolved,
) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=mgp.Nullable[int]):
"""LiteLLM-routed remote embedding path.

Fans chunks across a small thread pool, preserves input order, and
L2-normalizes client-side when ``cfg["normalize"]`` (default) so behavior
matches the local sentence_transformers path. Raises on permanent failure
— the caller catches and converts to ``success=False``.
"""

_model, provider, default_api_base = resolved
vertex_input = isinstance(cfg["embedding_property"], str)
texts = build_texts(input_items, cfg["excluded_properties"]) if vertex_input else list(input_items)

n = len(texts)
if n == 0:
return return_data(
input_items if vertex_input else [],
embedding_property_name=cfg["embedding_property"] if vertex_input else None,
return_embeddings=cfg["return_embeddings"],
success=True,
dimension=dimension,
)

chunk_size = cfg["remote_batch_size"] or default_remote_batch_size(provider)
[Comment, Copilot AI, Apr 23, 2026]

remote_batch_size is used as the range() step when chunking (range(0, n, chunk_size)). If a user passes remote_batch_size <= 0 (or a non-int), this will either raise (step=0) or silently produce no chunks (negative step), returning success=True without making any remote call.

Validate and coerce remote_batch_size (and possibly timeout/num_retries) in validate_configuration() to ensure it's an integer >= 1 when set.

Suggested change:
-    chunk_size = cfg["remote_batch_size"] or default_remote_batch_size(provider)
+    configured_chunk_size = cfg["remote_batch_size"]
+    if configured_chunk_size is None:
+        chunk_size = default_remote_batch_size(provider)
+    else:
+        try:
+            chunk_size = int(configured_chunk_size)
+        except (TypeError, ValueError):
+            raise ValueError("remote_batch_size must be an integer >= 1 when set.")
+        if chunk_size < 1:
+            raise ValueError("remote_batch_size must be an integer >= 1 when set.")
chunks = [texts[i : i + chunk_size] for i in range(0, n, chunk_size)]

api_base = cfg["api_base"] or default_api_base
base_kwargs = {
"model": cfg["model_name"],
"num_retries": cfg["num_retries"],
"timeout": cfg["timeout"],
}
if api_base:
base_kwargs["api_base"] = api_base
if cfg["input_type"]:
base_kwargs["input_type"] = cfg["input_type"]
if cfg["dimensions"]:
base_kwargs["dimensions"] = cfg["dimensions"]

def _call(chunk_texts):
resp = litellm.embedding(input=chunk_texts, **base_kwargs)
return [d["embedding"] for d in resp["data"]]

[Comment on lines +499 to +500, Copilot AI, Apr 23, 2026]

The remote embedding response is assumed to have 1:1 correspondence with the input chunk and to be in the same order ([d["embedding"] for d in resp["data"]]). Some providers return an index field or may not guarantee ordering; partial responses would also silently misalign embeddings when writing back to vertices.

Add a sanity check that len(resp["data"]) == len(chunk_texts) and, when available, sort by d["index"] (or equivalent) to preserve correct alignment before flattening/writing results.

Suggested change:
-        return [d["embedding"] for d in resp["data"]]
+        data = resp["data"]
+        if len(data) != len(chunk_texts):
+            raise ValueError(
+                f"Embedding provider returned {len(data)} embeddings for {len(chunk_texts)} inputs."
+            )
+        def _item_index(item):
+            if isinstance(item, dict):
+                return item.get("index")
+            return getattr(item, "index", None)
+        if data and all(_item_index(item) is not None for item in data):
+            data = sorted(data, key=_item_index)
+        return [d["embedding"] for d in data]
results = [None] * len(chunks)
with ThreadPoolExecutor(max_workers=max(1, cfg["concurrency"])) as ex:
fut2idx = {ex.submit(_call, c): i for i, c in enumerate(chunks)}
[Comment on lines +501 to +503, Copilot AI, Apr 23, 2026]

ThreadPoolExecutor(max_workers=max(1, cfg["concurrency"])) will raise a TypeError if concurrency is provided as a non-int (e.g., string/float from Cypher). It also allows extreme values that could create too many threads inside the query module.

Consider validating/coercing concurrency in validate_configuration() to an int within a safe range (e.g., >=1 and <= some cap), and fail fast with a clear error/log message when invalid.
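
A minimal sketch of that validation, assuming it runs inside validate_configuration() after the defaults are merged; the _coerce_int helper and the cap of 32 workers are illustrative, not part of the PR:

def _coerce_int(value, name, minimum=1, maximum=None):
    # Cypher maps can deliver numbers as floats (or strings); normalize to int.
    try:
        coerced = int(value)
    except (TypeError, ValueError):
        raise ValueError(f"'{name}' must be an integer, got {value!r}.")
    if coerced < minimum or (maximum is not None and coerced > maximum):
        raise ValueError(f"'{name}' must be between {minimum} and {maximum}, got {coerced}.")
    return coerced

# In validate_configuration(), after merging user config over the defaults:
# configuration["concurrency"] = _coerce_int(configuration["concurrency"], "concurrency", 1, 32)
# configuration["num_retries"] = _coerce_int(configuration["num_retries"], "num_retries", 0)
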
for fut in as_completed(fut2idx):
results[fut2idx[fut]] = fut.result()
[Comment on lines +502 to +505, Copilot AI, Apr 23, 2026]

If any chunk future raises, fut.result() will propagate immediately, but the with ThreadPoolExecutor(...) context will still wait for all other in-flight futures to finish on exit. This can significantly increase end-to-end latency on partial failure (e.g., one chunk timing out) even though the call is already considered failed.

Consider catching exceptions per-future, cancelling remaining futures (when possible), and re-raising after shutdown with cancel_futures=True (or equivalent) so failures return promptly.

Suggested change:
-    with ThreadPoolExecutor(max_workers=max(1, cfg["concurrency"])) as ex:
-        fut2idx = {ex.submit(_call, c): i for i, c in enumerate(chunks)}
-        for fut in as_completed(fut2idx):
-            results[fut2idx[fut]] = fut.result()
+    ex = ThreadPoolExecutor(max_workers=max(1, cfg["concurrency"]))
+    fut2idx = {}
+    failed = False
+    try:
+        fut2idx = {ex.submit(_call, c): i for i, c in enumerate(chunks)}
+        for fut in as_completed(fut2idx):
+            try:
+                results[fut2idx[fut]] = fut.result()
+            except Exception:
+                failed = True
+                for other in fut2idx:
+                    if other is not fut:
+                        other.cancel()
+                raise
+    finally:
+        if failed:
+            ex.shutdown(wait=False, cancel_futures=True)
+        else:
+            ex.shutdown(wait=True)

flat = [e for part in results for e in part]
if cfg["normalize"]:
flat = [l2_normalize(e) for e in flat]

if vertex_input:
for v, e in zip(input_items, flat):
v.properties[cfg["embedding_property"]] = e

logger.info(f"Processed {n} items via LiteLLM provider '{provider}' (model={cfg['model_name']}).")
return return_data(
input_items if vertex_input else flat,
embedding_property_name=cfg["embedding_property"] if vertex_input else None,
return_embeddings=cfg["return_embeddings"],
success=True,
dimension=dimension or (len(flat[0]) if flat else None),
)


def return_data(
input_items: mgp.Any,
embedding_property_name: mgp.Nullable[str] = "embedding",
Expand Down Expand Up @@ -440,13 +555,26 @@

def validate_configuration(configuration: mgp.Map):
default_configuration = {
# Local path
"embedding_property": "embedding",
"excluded_properties": ["embedding"],
"model_name": "all-MiniLM-L6-v2",
"batch_size": 2000,
"chunk_size": 48,
"device": None,
"return_embeddings": False,
# Remote path (via LiteLLM). These are only used when model_name resolves
# to a LiteLLM-known provider (e.g. "openai/text-embedding-3-small").
# API keys are NOT accepted here — LiteLLM reads canonical provider env
# vars (OPENAI_API_KEY, COHERE_API_KEY, VOYAGE_API_KEY, ...).
"api_base": None,
"input_type": "document",
"dimensions": None,
"timeout": 60,
"num_retries": 3,
"normalize": True,
"remote_batch_size": None,
"concurrency": 4,
}
configuration = {**default_configuration, **configuration}

@@ -463,21 +591,74 @@
if configuration["embedding_property"] is not None and configuration["embedding_property"] not in excluded:
excluded.append(configuration["embedding_property"])

logger.debug(f"Using embedding configuration: {configuration}")
# When routing remote, a local `device` setting has no meaning — warn so the
# user knows we're ignoring it rather than silently doing the wrong thing.
if resolve_remote_model(configuration["model_name"]) is not None and configuration["device"] is not None:
logger.warning(
f"'device' is ignored when model_name '{configuration['model_name']}' routes to a remote provider."
)

logger.debug(f"Using embedding configuration: {_redacted_config(configuration)}")

return configuration


def compute_embeddings(
def _redacted_config(cfg):
"""Return a shallow-copied config with URL-embedded credentials masked.

We don't accept secret-bearing config keys (credentials come from provider
env vars), so the only vector left is an ``api_base`` URL of the form
``https://user:pass@host/...`` — scrub just those.
"""
import re

api_base = cfg.get("api_base")
if not isinstance(api_base, str) or "@" not in api_base:
return cfg
return {**cfg, "api_base": re.sub(r"(://)[^/@\s]+:[^/@\s]+@", r"\1<redacted>@", api_base)}
[Comment on lines +615 to +618, Copilot AI, Apr 23, 2026]

_redacted_config() only redacts credentials in the user:pass@host URL form. Secrets can still leak into logs if api_base contains tokens in the query string or path (e.g., ...?token=...), and this function currently returns the original config unchanged in those cases.

Consider parsing api_base with urllib.parse and redacting userinfo plus query/fragment (or whitelisting to scheme/host/port only) before logging.
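
A minimal sketch of the whitelisting variant, assuming it is acceptable to log only scheme, host, port, and path; the _safe_api_base name is illustrative, not part of the PR:

from urllib.parse import urlsplit

def _safe_api_base(api_base):
    # Keep scheme://host[:port]/path and drop userinfo, query, and fragment,
    # the URL parts that can carry credentials or tokens.
    if not isinstance(api_base, str) or not api_base:
        return api_base
    parts = urlsplit(api_base)
    host = parts.hostname or ""
    if parts.port:
        host = f"{host}:{parts.port}"
    return f"{parts.scheme}://{host}{parts.path}"

# _redacted_config() could then return {**cfg, "api_base": _safe_api_base(cfg.get("api_base"))}
# unconditionally, instead of pattern-matching only the user:pass@host form.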


def compute_embeddings( # noqa: C901

[SonarCloud Code Analysis check failure on line 621 in mage/python/embeddings.py]
Refactor this function to reduce its Cognitive Complexity from 19 to the 15 allowed.
See more on https://sonarcloud.io/project/issues?id=memgraph_memgraph&issues=AZ22DjuAseNgnG04vSIZ&open=AZ22DjuAseNgnG04vSIZ&pullRequest=4064
input_items: mgp.Any,
configuration: mgp.Map,
) -> mgp.Any:
dimension = get_model_info(configuration).get("dimension", None)
if dimension is None:
logger.warning("Failed to get model dimension.")

try:
n = len(input_items)

# Route to LiteLLM if model_name names a known remote provider
# (e.g. "openai/text-embedding-3-small", "ollama/nomic-embed-text").
# Bare names and HF-style names like "BAAI/bge-small-en-v1.5" fall
# through to the local SentenceTransformer path below.
resolved = resolve_remote_model(configuration["model_name"])
if resolved is not None:
if n == 0:
logger.info("No items to process.")
return return_data(
input_items,
configuration["embedding_property"],
configuration["return_embeddings"],
True,
dimension=None,

[SonarCloud Code Analysis check failure on line 641 in mage/python/embeddings.py]
Change this argument; Function "return_data" expects a different type
See more on https://sonarcloud.io/project/issues?id=memgraph_memgraph&issues=AZ22DjuAseNgnG04vSIa&open=AZ22DjuAseNgnG04vSIa&pullRequest=4064
)
try:
# Dimension is derived from the encode response inside
# remote_compute — no separate probe needed.
return remote_compute(input_items, configuration, None, resolved)

[SonarCloud Code Analysis check failure on line 646 in mage/python/embeddings.py]
Change this argument; Function "remote_compute" expects a different type
See more on https://sonarcloud.io/project/issues?id=memgraph_memgraph&issues=AZ22DjuAseNgnG04vSIb&open=AZ22DjuAseNgnG04vSIb&pullRequest=4064
except Exception as e:
logger.error(f"Remote path failed: {e}")
return return_data(
input_items,
configuration["embedding_property"],
configuration["return_embeddings"],
False,
dimension=None,

[SonarCloud Code Analysis check failure on line 654 in mage/python/embeddings.py]
Change this argument; Function "return_data" expects a different type
See more on https://sonarcloud.io/project/issues?id=memgraph_memgraph&issues=AZ22DjuAseNgnG04vSIc&open=AZ22DjuAseNgnG04vSIc&pullRequest=4064
)

# Local path: probe the SentenceTransformer for dimension up front.
dimension = get_model_info(configuration).get("dimension", None)
if dimension is None:
logger.warning("Failed to get model dimension.")

if n == 0:
logger.info("No vertices to process.")
return return_data(
Expand Down Expand Up @@ -579,7 +760,15 @@
)


_remote_info_cache = {}
_remote_info_lock = threading.Lock()


def get_model_info(configuration: mgp.Map):
resolved = resolve_remote_model(configuration["model_name"])
if resolved is not None:
return _remote_model_info(configuration, resolved)

model = _get_or_load_model(configuration["model_name"], "cpu")

info = {
Expand All @@ -591,6 +780,44 @@
return info


def _remote_model_info(configuration, resolved):
"""Probe a remote provider once to learn its embedding dimension.

Cached per (model_name, api_base) so subsequent calls don't re-hit the API.
"""
_model, _provider, default_api_base = resolved
api_base = configuration.get("api_base") or default_api_base
cache_key = (configuration["model_name"], api_base)
[Comment on lines 763 to +790, Copilot AI, Apr 23, 2026]

_remote_info_cache is an unbounded global dict keyed by (model_name, api_base). In long-running Memgraph processes with many unique model/api_base values, this can grow without limit.

Consider adding a simple bound/eviction policy (e.g., LRU with a max size) or a TTL to avoid unbounded memory growth.
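
A minimal sketch of the bounded-LRU option, keeping the module's existing lock pattern; the OrderedDict approach and the cap of 64 entries are illustrative, not part of the PR:

import threading
from collections import OrderedDict

_REMOTE_INFO_MAX_ENTRIES = 64  # illustrative cap

_remote_info_cache = OrderedDict()
_remote_info_lock = threading.Lock()

def _cache_get(cache_key):
    with _remote_info_lock:
        info = _remote_info_cache.get(cache_key)
        if info is not None:
            _remote_info_cache.move_to_end(cache_key)  # mark as most recently used
        return info

def _cache_put(cache_key, info):
    with _remote_info_lock:
        _remote_info_cache[cache_key] = info
        _remote_info_cache.move_to_end(cache_key)
        while len(_remote_info_cache) > _REMOTE_INFO_MAX_ENTRIES:
            _remote_info_cache.popitem(last=False)  # evict least recently used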

with _remote_info_lock:
cached = _remote_info_cache.get(cache_key)
if cached is not None:
return cached

kwargs = {
"model": configuration["model_name"],
"input": ["probe"],
"num_retries": configuration.get("num_retries", 3),
"timeout": configuration.get("timeout", 60),
}
if api_base:
kwargs["api_base"] = api_base
if configuration.get("dimensions"):
kwargs["dimensions"] = configuration["dimensions"]

resp = litellm.embedding(**kwargs)
dim = len(resp["data"][0]["embedding"])
info = {
"model_name": configuration["model_name"],
"dimension": dim,
"max_sequence_length": None,
}

with _remote_info_lock:
_remote_info_cache[cache_key] = info
return info


@mgp.read_proc
def model_info(
configuration: mgp.Map = {},
@@ -606,7 +833,7 @@
ctx: mgp.ProcCtx,
input_nodes: mgp.Nullable[mgp.List[mgp.Vertex]] = None,
configuration: mgp.Map = {},
) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=int):
) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=mgp.Nullable[int]):
logger.info(f"compute_embeddings: starting (py_exec={sys.executable}, py_ver={sys.version.split()[0]})")

configuration = validate_configuration(configuration)
@@ -623,7 +850,7 @@
ctx: mgp.ProcCtx,
input_strings: mgp.List[str],
configuration: mgp.Map = {},
) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=int):
) -> mgp.Record(success=bool, embeddings=mgp.Nullable[mgp.List[list]], dimension=mgp.Nullable[int]):
logger.info(f"embed: starting (py_exec={sys.executable}, py_ver={sys.version.split()[0]})")

# hard code embedding_property to None for string input
@@ -0,0 +1 @@
CREATE (:Doc {title: "remote routing failure fixture"});
@@ -0,0 +1,15 @@
query: >
  MATCH (n:Doc)
  WITH n, {
    model_name: "ollama/nomic-embed-text",
    api_base: "http://127.0.0.1:1",
    num_retries: 0,
    timeout: 2
  } AS configuration
  CALL embeddings.node_sentence([n], configuration)
  YIELD success
  RETURN success, n.embedding IS NULL AS no_embedding_written;

output:
  - success: false
    no_embedding_written: true