diff --git a/mcp_plex/server.py b/mcp_plex/server.py index 435d328..f2661ed 100644 --- a/mcp_plex/server.py +++ b/mcp_plex/server.py @@ -8,11 +8,11 @@ from collections import OrderedDict from typing import Annotated, Any -from fastembed import SparseTextEmbedding, TextEmbedding from fastmcp.server import FastMCP from pydantic import Field from qdrant_client import models from qdrant_client.async_qdrant_client import AsyncQdrantClient +from qdrant_client.hybrid.fusion import reciprocal_rank_fusion try: from sentence_transformers import CrossEncoder @@ -43,7 +43,7 @@ if _QDRANT_URL is None and _QDRANT_HOST is None: _QDRANT_URL = ":memory:" -# Instantiate global client and embedding models +# Instantiate global client _client = AsyncQdrantClient( location=_QDRANT_URL, api_key=_QDRANT_API_KEY, @@ -53,8 +53,6 @@ prefer_grpc=_QDRANT_PREFER_GRPC, https=_QDRANT_HTTPS, ) -_dense_model = TextEmbedding(_DENSE_MODEL_NAME) -_sparse_model = SparseTextEmbedding(_SPARSE_MODEL_NAME) _USE_RERANKER = os.getenv("USE_RERANKER", "1") == "1" _reranker = None @@ -179,17 +177,26 @@ async def search_media( ] = 5, ) -> list[dict[str, Any]]: """Hybrid similarity search across media items using dense and sparse vectors.""" - dense_task = asyncio.to_thread(lambda: list(_dense_model.embed([query]))[0]) - dense_vec = await dense_task - named_dense = models.NamedVector(name="dense", vector=dense_vec) + dense_doc = models.Document(text=query, model=_DENSE_MODEL_NAME) + sparse_doc = models.Document(text=query, model=_SPARSE_MODEL_NAME) candidate_limit = limit * 3 if _reranker is not None else limit - hits = await _client.search( - collection_name="media-items", - query_vector=named_dense, - query_filter=None, - limit=candidate_limit, - with_payload=True, + dense_resp, sparse_resp = await asyncio.gather( + _client.query_points( + collection_name="media-items", + query=dense_doc, + using="dense", + limit=candidate_limit, + with_payload=True, + ), + _client.query_points( + collection_name="media-items", + query=sparse_doc, + using="sparse", + limit=candidate_limit, + with_payload=True, + ), ) + hits = reciprocal_rank_fusion([dense_resp.points, sparse_resp.points], limit=candidate_limit) async def _prefetch(hit: models.ScoredPoint) -> None: data = hit.payload["data"] @@ -468,6 +475,7 @@ async def media_background( def main(argv: list[str] | None = None) -> None: """CLI entrypoint for running the MCP server.""" + global _DENSE_MODEL_NAME, _SPARSE_MODEL_NAME parser = argparse.ArgumentParser(description="Run the MCP server") parser.add_argument("--bind", help="Host address to bind to") parser.add_argument("--port", type=int, help="Port to listen on") @@ -502,9 +510,8 @@ def main(argv: list[str] | None = None) -> None: if args.mount: run_kwargs["path"] = args.mount - global _dense_model, _sparse_model - _dense_model = TextEmbedding(args.dense_model) - _sparse_model = SparseTextEmbedding(args.sparse_model) + _DENSE_MODEL_NAME = args.dense_model + _SPARSE_MODEL_NAME = args.sparse_model server.run(transport=args.transport, **run_kwargs) diff --git a/pyproject.toml b/pyproject.toml index f875ed5..c7666a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mcp-plex" -version = "0.26.0" +version = "0.26.1" description = "Plex-Oriented Model Context Protocol Server" requires-python = ">=3.11,<4" diff --git a/tests/test_server_cli.py b/tests/test_server_cli.py index 7d61458..0b52068 100644 --- a/tests/test_server_cli.py +++ b/tests/test_server_cli.py @@ -2,28 +2,7 @@ import pytest -class _StubDense: - def __init__(self, *args, **kwargs) -> None: - pass - - @staticmethod - def list_supported_models() -> list[str]: - return ["stub-dense"] - - -class _StubSparse: - def __init__(self, *args, **kwargs) -> None: - pass - - @staticmethod - def list_supported_models() -> list[str]: - return ["stub"] - - -with patch("fastembed.TextEmbedding", _StubDense), patch( - "fastembed.SparseTextEmbedding", _StubSparse -): - from mcp_plex import server +from mcp_plex import server def test_main_stdio_runs(): @@ -51,32 +30,26 @@ def test_main_http_with_mount_runs(): def test_main_model_overrides(): - with patch("mcp_plex.server.TextEmbedding") as mock_dense, patch( - "mcp_plex.server.SparseTextEmbedding" - ) as mock_sparse, patch.object(server.server, "run") as mock_run: + with patch.object(server.server, "run") as mock_run: server.main([ "--dense-model", "foo", "--sparse-model", "bar", ]) - mock_dense.assert_called_with("foo") - mock_sparse.assert_called_with("bar") + assert server._DENSE_MODEL_NAME == "foo" + assert server._SPARSE_MODEL_NAME == "bar" mock_run.assert_called_once_with(transport="stdio") def test_env_model_overrides(monkeypatch): - with patch("fastembed.TextEmbedding") as mock_dense, patch( - "fastembed.SparseTextEmbedding" - ) as mock_sparse: - monkeypatch.setenv("DENSE_MODEL", "foo") - monkeypatch.setenv("SPARSE_MODEL", "bar") - import importlib - - importlib.reload(server) - mock_dense.assert_called_with("foo") - mock_sparse.assert_called_with("bar") - with patch("fastembed.TextEmbedding"), patch("fastembed.SparseTextEmbedding"): - import importlib - - importlib.reload(server) + monkeypatch.setenv("DENSE_MODEL", "foo") + monkeypatch.setenv("SPARSE_MODEL", "bar") + import importlib + + importlib.reload(server) + assert server._DENSE_MODEL_NAME == "foo" + assert server._SPARSE_MODEL_NAME == "bar" + + # reload to reset globals + importlib.reload(server) diff --git a/uv.lock b/uv.lock index fbf16df..e95edf7 100644 --- a/uv.lock +++ b/uv.lock @@ -801,7 +801,7 @@ wheels = [ [[package]] name = "mcp-plex" -version = "0.26.0" +version = "0.26.1" source = { editable = "." } dependencies = [ { name = "fastmcp" },