From 9184f05fc8bdba358b4fabb0d2b6087ec5ff2630 Mon Sep 17 00:00:00 2001 From: Teagan Glenn Date: Sat, 13 Sep 2025 23:07:40 -0600 Subject: [PATCH] refactor: lazily initialize reranker --- AGENTS.md | 2 ++ mcp_plex/server.py | 34 ++++++++++++++++++++-------------- pyproject.toml | 2 +- tests/test_server.py | 8 ++++---- uv.lock | 2 +- 5 files changed, 28 insertions(+), 20 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 9fb6540..b3049c0 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -13,6 +13,8 @@ to combine dense and sparse results before optional cross-encoder reranking. - Qdrant client initialization moved into `PlexServer` to centralize state and simplify testing. +- Cross-encoder reranker is initialized lazily via a `PlexServer` property to + avoid unnecessary model downloads when reranking is disabled or unavailable. ## User Queries The project should handle natural-language searches and recommendations such as: diff --git a/mcp_plex/server.py b/mcp_plex/server.py index d2cf832..a43b204 100644 --- a/mcp_plex/server.py +++ b/mcp_plex/server.py @@ -21,11 +21,6 @@ from starlette.requests import Request from starlette.responses import JSONResponse, PlainTextResponse, Response -try: - from sentence_transformers import CrossEncoder -except Exception: - CrossEncoder = None - try: from sentence_transformers import CrossEncoder except Exception: @@ -65,14 +60,24 @@ def __init__(self) -> None: # noqa: D401 - short description inherited prefer_grpc=_QDRANT_PREFER_GRPC, https=_QDRANT_HTTPS, ) + self._reranker: CrossEncoder | None = None + self._reranker_loaded = False + + @property + def reranker(self) -> CrossEncoder | None: + if not _USE_RERANKER or CrossEncoder is None: + return None + if not self._reranker_loaded: + try: + self._reranker = CrossEncoder( + "cross-encoder/ms-marco-MiniLM-L-6-v2" + ) + except Exception: + self._reranker = None + self._reranker_loaded = True + return self._reranker _USE_RERANKER = os.getenv("USE_RERANKER", "1") == "1" -_reranker = None -if _USE_RERANKER and CrossEncoder is not None: - try: - _reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2") - except Exception: - _reranker = None server = PlexServer() @@ -193,7 +198,8 @@ async def search_media( """Hybrid similarity search across media items using dense and sparse vectors.""" dense_doc = models.Document(text=query, model=_DENSE_MODEL_NAME) sparse_doc = models.Document(text=query, model=_SPARSE_MODEL_NAME) - candidate_limit = limit * 3 if _reranker is not None else limit + reranker = server.reranker + candidate_limit = limit * 3 if reranker is not None else limit prefetch = [ models.Prefetch( query=models.NearestQuery(nearest=dense_doc), @@ -230,7 +236,7 @@ async def _prefetch(hit: models.ScoredPoint) -> None: prefetch_task = asyncio.gather(*[_prefetch(h) for h in hits[:limit]]) def _rerank(hits: list[models.ScoredPoint]) -> list[models.ScoredPoint]: - if _reranker is None: + if reranker is None: return hits docs: list[str] = [] for h in hits: @@ -244,7 +250,7 @@ def _rerank(hits: list[models.ScoredPoint]) -> list[models.ScoredPoint]: ] docs.append(" ".join(p for p in parts if p)) pairs = [(query, d) for d in docs] - scores = _reranker.predict(pairs) + scores = reranker.predict(pairs) for h, s in zip(hits, scores): h.score = float(s) hits.sort(key=lambda h: h.score, reverse=True) diff --git a/pyproject.toml b/pyproject.toml index d976ac0..ae54d1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mcp-plex" -version = "0.26.10" +version = "0.26.11" description = "Plex-Oriented Model Context Protocol Server" requires-python = ">=3.11,<3.13" diff --git a/tests/test_server.py b/tests/test_server.py index ba518b1..5121d9f 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -142,8 +142,8 @@ def fake_import(name, *args, **kwargs): return orig_import(name, *args, **kwargs) monkeypatch.setattr(builtins, "__import__", fake_import) - server = importlib.reload(importlib.import_module("mcp_plex.server")) - assert server._reranker is None + module = importlib.reload(importlib.import_module("mcp_plex.server")) + assert module.server.reranker is None def test_reranker_init_failure(monkeypatch): @@ -156,8 +156,8 @@ def __init__(self, *args, **kwargs): st_module.CrossEncoder = Broken monkeypatch.setitem(sys.modules, "sentence_transformers", st_module) - server = importlib.reload(importlib.import_module("mcp_plex.server")) - assert server._reranker is None + module = importlib.reload(importlib.import_module("mcp_plex.server")) + assert module.server.reranker is None def test_rest_endpoints(monkeypatch): diff --git a/uv.lock b/uv.lock index 817dfa1..b7e6863 100644 --- a/uv.lock +++ b/uv.lock @@ -690,7 +690,7 @@ wheels = [ [[package]] name = "mcp-plex" -version = "0.26.10" +version = "0.26.11" source = { editable = "." } dependencies = [ { name = "fastapi" },