From d80dd9ca047b676b5328cde35a7f6f513ea1a0f3 Mon Sep 17 00:00:00 2001 From: Teagan Glenn Date: Mon, 8 Sep 2025 23:23:36 -0600 Subject: [PATCH 1/2] feat: allow embedding model configuration --- AGENTS.md | 2 ++ README.md | 12 +++++++++ mcp_plex/loader.py | 30 +++++++++++++++++++-- mcp_plex/server.py | 24 +++++++++++++++-- pyproject.toml | 2 +- tests/test_loader_cli.py | 52 ++++++++++++++++++++++++++++++++++++ tests/test_server_cli.py | 57 +++++++++++++++++++++++++++++++++++++++- uv.lock | 6 ++--- 8 files changed, 176 insertions(+), 9 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index f2eacae..ce226b2 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -6,6 +6,8 @@ - `mcp_plex/types.py` defines the Pydantic models used across the project. - When making architectural design decisions, add a short note here describing the decision and its rationale. - Actor names are stored as a top-level payload field and indexed in Qdrant to enable actor and year-based filtering. +- Dense and sparse embedding model names are configurable via `DENSE_MODEL` and + `SPARSE_MODEL` environment variables or the corresponding CLI options. ## User Queries The project should handle natural-language searches and recommendations such as: diff --git a/README.md b/README.md index 06a9ed2..f4673f3 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,18 @@ Expose the server over SSE on port 8000: uv run mcp-server --transport sse --bind 0.0.0.0 --port 8000 --mount /mcp ``` +### Embedding Models + +Both the loader and server default to `BAAI/bge-small-en-v1.5` for dense +embeddings and `Qdrant/bm42-all-minilm-l6-v2-attentions` for sparse embeddings. +Override these by setting `DENSE_MODEL`/`SPARSE_MODEL` environment variables or +using `--dense-model`/`--sparse-model` CLI options: + +```bash +uv run load-data --dense-model my-dense --sparse-model my-sparse +uv run mcp-server --dense-model my-dense --sparse-model my-sparse +``` + ## Docker A Dockerfile builds a GPU-enabled image based on `nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04` using `uv` for dependency diff --git a/mcp_plex/loader.py b/mcp_plex/loader.py index 7e5bd1d..36a7807 100644 --- a/mcp_plex/loader.py +++ b/mcp_plex/loader.py @@ -293,6 +293,8 @@ async def run( qdrant_grpc_port: int = 6334, qdrant_https: bool = False, qdrant_prefer_grpc: bool = False, + dense_model_name: str = "BAAI/bge-small-en-v1.5", + sparse_model_name: str = "Qdrant/bm42-all-minilm-l6-v2-attentions", ) -> None: """Core execution logic for the CLI.""" @@ -325,8 +327,8 @@ async def run( parts.extend(r.get("content", "") for r in getattr(item.tmdb, "reviews", [])) texts.append("\n".join(p for p in parts if p)) - dense_model = TextEmbedding("BAAI/bge-small-en-v1.5") - sparse_model = SparseTextEmbedding("Qdrant/bm42-all-minilm-l6-v2-attentions") + dense_model = TextEmbedding(dense_model_name) + sparse_model = SparseTextEmbedding(sparse_model_name) dense_vectors = list(dense_model.embed(texts)) sparse_vectors = list(sparse_model.passage_embed(texts)) @@ -533,6 +535,22 @@ async def run( default=False, help="Prefer gRPC when connecting to Qdrant", ) +@click.option( + "--dense-model", + envvar="DENSE_MODEL", + show_envvar=True, + default="BAAI/bge-small-en-v1.5", + show_default=True, + help="Dense embedding model name", +) +@click.option( + "--sparse-model", + envvar="SPARSE_MODEL", + show_envvar=True, + default="Qdrant/bm42-all-minilm-l6-v2-attentions", + show_default=True, + help="Sparse embedding model name", +) @click.option( "--continuous", is_flag=True, @@ -561,6 +579,8 @@ def main( qdrant_grpc_port: int, qdrant_https: bool, qdrant_prefer_grpc: bool, + dense_model: str, + sparse_model: str, continuous: bool, delay: float, ) -> None: @@ -579,6 +599,8 @@ def main( qdrant_grpc_port, qdrant_https, qdrant_prefer_grpc, + dense_model, + sparse_model, continuous, delay, ) @@ -597,6 +619,8 @@ async def load_media( qdrant_grpc_port: int, qdrant_https: bool, qdrant_prefer_grpc: bool, + dense_model_name: str, + sparse_model_name: str, continuous: bool, delay: float, ) -> None: @@ -615,6 +639,8 @@ async def load_media( qdrant_grpc_port, qdrant_https, qdrant_prefer_grpc, + dense_model_name, + sparse_model_name, ) if not continuous: break diff --git a/mcp_plex/server.py b/mcp_plex/server.py index 5ac67ed..435d328 100644 --- a/mcp_plex/server.py +++ b/mcp_plex/server.py @@ -34,6 +34,12 @@ _https_env = os.getenv("QDRANT_HTTPS") _QDRANT_HTTPS = None if _https_env is None else _https_env == "1" +# Embedding model configuration +_DENSE_MODEL_NAME = os.getenv("DENSE_MODEL", "BAAI/bge-small-en-v1.5") +_SPARSE_MODEL_NAME = os.getenv( + "SPARSE_MODEL", "Qdrant/bm42-all-minilm-l6-v2-attentions" +) + if _QDRANT_URL is None and _QDRANT_HOST is None: _QDRANT_URL = ":memory:" @@ -47,8 +53,8 @@ prefer_grpc=_QDRANT_PREFER_GRPC, https=_QDRANT_HTTPS, ) -_dense_model = TextEmbedding("BAAI/bge-small-en-v1.5") -_sparse_model = SparseTextEmbedding("Qdrant/bm42-all-minilm-l6-v2-attentions") +_dense_model = TextEmbedding(_DENSE_MODEL_NAME) +_sparse_model = SparseTextEmbedding(_SPARSE_MODEL_NAME) _USE_RERANKER = os.getenv("USE_RERANKER", "1") == "1" _reranker = None @@ -472,6 +478,16 @@ def main(argv: list[str] | None = None) -> None: help="Transport protocol to use", ) parser.add_argument("--mount", help="Mount path for HTTP transports") + parser.add_argument( + "--dense-model", + default=_DENSE_MODEL_NAME, + help="Dense embedding model name (env: DENSE_MODEL)", + ) + parser.add_argument( + "--sparse-model", + default=_SPARSE_MODEL_NAME, + help="Sparse embedding model name (env: SPARSE_MODEL)", + ) args = parser.parse_args(argv) if args.transport != "stdio": @@ -486,6 +502,10 @@ def main(argv: list[str] | None = None) -> None: if args.mount: run_kwargs["path"] = args.mount + global _dense_model, _sparse_model + _dense_model = TextEmbedding(args.dense_model) + _sparse_model = SparseTextEmbedding(args.sparse_model) + server.run(transport=args.transport, **run_kwargs) diff --git a/pyproject.toml b/pyproject.toml index 51e902d..f875ed5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mcp-plex" -version = "0.25.0" +version = "0.26.0" description = "Plex-Oriented Model Context Protocol Server" requires-python = ">=3.11,<4" diff --git a/tests/test_loader_cli.py b/tests/test_loader_cli.py index da46024..0214d23 100644 --- a/tests/test_loader_cli.py +++ b/tests/test_loader_cli.py @@ -62,3 +62,55 @@ async def invoke(): with pytest.raises(RuntimeError, match="PLEX_URL and PLEX_TOKEN must be provided"): asyncio.run(invoke()) + + +def test_cli_model_overrides(monkeypatch): + captured: dict[str, str] = {} + + async def fake_run(*args, **kwargs): + captured["dense"] = args[-2] + captured["sparse"] = args[-1] + + monkeypatch.setattr(loader, "run", fake_run) + + runner = CliRunner() + runner.invoke( + loader.main, + ["--dense-model", "foo", "--sparse-model", "bar"], + catch_exceptions=False, + env={ + "PLEX_URL": "http://localhost", + "PLEX_TOKEN": "token", + "TMDB_API_KEY": "key", + }, + ) + + assert captured["dense"] == "foo" + assert captured["sparse"] == "bar" + + +def test_cli_model_env(monkeypatch): + captured: dict[str, str] = {} + + async def fake_run(*args, **kwargs): + captured["dense"] = args[-2] + captured["sparse"] = args[-1] + + monkeypatch.setattr(loader, "run", fake_run) + + runner = CliRunner() + runner.invoke( + loader.main, + [], + catch_exceptions=False, + env={ + "PLEX_URL": "http://localhost", + "PLEX_TOKEN": "token", + "TMDB_API_KEY": "key", + "DENSE_MODEL": "foo", + "SPARSE_MODEL": "bar", + }, + ) + + assert captured["dense"] == "foo" + assert captured["sparse"] == "bar" diff --git a/tests/test_server_cli.py b/tests/test_server_cli.py index c7fdff5..d8701b2 100644 --- a/tests/test_server_cli.py +++ b/tests/test_server_cli.py @@ -1,8 +1,31 @@ from unittest.mock import patch +from unittest.mock import patch + import pytest -from mcp_plex import server +class _StubDense: + def __init__(self, *args, **kwargs) -> None: + pass + + @staticmethod + def list_supported_models() -> list[str]: + return ["stub-dense"] + + +class _StubSparse: + def __init__(self, *args, **kwargs) -> None: + pass + + @staticmethod + def list_supported_models() -> list[str]: + return ["stub"] + + +with patch("fastembed.TextEmbedding", _StubDense), patch( + "fastembed.SparseTextEmbedding", _StubSparse +): + from mcp_plex import server def test_main_stdio_runs(): @@ -27,3 +50,35 @@ def test_main_http_with_mount_runs(): with patch.object(server.server, "run") as mock_run: server.main(["--transport", "sse", "--bind", "0.0.0.0", "--port", "8000", "--mount", "/mcp"]) mock_run.assert_called_once_with(transport="sse", host="0.0.0.0", port=8000, path="/mcp") + + +def test_main_model_overrides(): + with patch("mcp_plex.server.TextEmbedding") as mock_dense, patch( + "mcp_plex.server.SparseTextEmbedding" + ) as mock_sparse, patch.object(server.server, "run") as mock_run: + server.main([ + "--dense-model", + "foo", + "--sparse-model", + "bar", + ]) + mock_dense.assert_called_with("foo") + mock_sparse.assert_called_with("bar") + mock_run.assert_called_once_with(transport="stdio") + + +def test_env_model_overrides(monkeypatch): + with patch("fastembed.TextEmbedding") as mock_dense, patch( + "fastembed.SparseTextEmbedding" + ) as mock_sparse: + monkeypatch.setenv("DENSE_MODEL", "foo") + monkeypatch.setenv("SPARSE_MODEL", "bar") + import importlib + + importlib.reload(server) + mock_dense.assert_called_with("foo") + mock_sparse.assert_called_with("bar") + with patch("fastembed.TextEmbedding"), patch("fastembed.SparseTextEmbedding"): + import importlib + + importlib.reload(server) diff --git a/uv.lock b/uv.lock index 0c4fe0a..fbf16df 100644 --- a/uv.lock +++ b/uv.lock @@ -801,7 +801,7 @@ wheels = [ [[package]] name = "mcp-plex" -version = "0.25.0" +version = "0.26.0" source = { editable = "." } dependencies = [ { name = "fastmcp" }, @@ -812,13 +812,13 @@ dependencies = [ { name = "qdrant-client", version = "1.15.1", source = { registry = "https://pypi.org/simple" }, extra = ["fastembed-gpu"], marker = "python_full_version < '3.13'" }, { name = "rapidfuzz" }, { name = "scikit-learn" }, + { name = "sentence-transformers" }, ] [package.optional-dependencies] dev = [ { name = "pytest" }, { name = "ruff" }, - { name = "sentence-transformers" }, ] [package.metadata] @@ -832,7 +832,7 @@ requires-dist = [ { name = "rapidfuzz", specifier = ">=3.13.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.12.0" }, { name = "scikit-learn", specifier = ">=1.7.1" }, - { name = "sentence-transformers", marker = "extra == 'dev'", specifier = ">=2.7.0" }, + { name = "sentence-transformers", specifier = ">=2.7.0" }, ] provides-extras = ["dev"] From 74a9245d37da4fcc60c65f75c27f6651bd9d559f Mon Sep 17 00:00:00 2001 From: Teagan Glenn Date: Mon, 8 Sep 2025 23:30:19 -0600 Subject: [PATCH 2/2] test: remove duplicate patch import --- tests/test_server_cli.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_server_cli.py b/tests/test_server_cli.py index d8701b2..7d61458 100644 --- a/tests/test_server_cli.py +++ b/tests/test_server_cli.py @@ -1,7 +1,5 @@ from unittest.mock import patch -from unittest.mock import patch - import pytest class _StubDense: