diff --git a/AGENTS.md b/AGENTS.md index a7895ec..df937f8 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -13,6 +13,8 @@ to combine dense and sparse results before optional cross-encoder reranking. - Qdrant client initialization moved into `PlexServer` to centralize state and simplify testing. +- Cross-encoder reranker is initialized lazily via a `PlexServer` property to + avoid unnecessary model downloads when reranking is disabled or unavailable. - Media payload and artwork caching centralized in `MediaCache` attached to `PlexServer` to streamline cache management and testing. diff --git a/mcp_plex/server.py b/mcp_plex/server.py index 605d847..82a771d 100644 --- a/mcp_plex/server.py +++ b/mcp_plex/server.py @@ -22,11 +22,6 @@ from .cache import MediaCache -try: - from sentence_transformers import CrossEncoder -except Exception: - CrossEncoder = None - try: from sentence_transformers import CrossEncoder except Exception: @@ -69,15 +64,26 @@ def __init__(self) -> None: # noqa: D401 - short description inherited prefer_grpc=_QDRANT_PREFER_GRPC, https=_QDRANT_HTTPS, ) + self._reranker: CrossEncoder | None = None + self._reranker_loaded = False self.cache = MediaCache(_CACHE_SIZE) + @property + def reranker(self) -> CrossEncoder | None: + if not _USE_RERANKER or CrossEncoder is None: + return None + if not self._reranker_loaded: + try: + self._reranker = CrossEncoder( + "cross-encoder/ms-marco-MiniLM-L-6-v2" + ) + except Exception: + self._reranker = None + self._reranker_loaded = True + return self._reranker + + _USE_RERANKER = os.getenv("USE_RERANKER", "1") == "1" -_reranker = None -if _USE_RERANKER and CrossEncoder is not None: - try: - _reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2") - except Exception: - _reranker = None server = PlexServer() @@ -177,7 +183,8 @@ async def search_media( """Hybrid similarity search across media items using dense and sparse vectors.""" dense_doc = models.Document(text=query, model=_DENSE_MODEL_NAME) sparse_doc = models.Document(text=query, model=_SPARSE_MODEL_NAME) - candidate_limit = limit * 3 if _reranker is not None else limit + reranker = server.reranker + candidate_limit = limit * 3 if reranker is not None else limit prefetch = [ models.Prefetch( query=models.NearestQuery(nearest=dense_doc), @@ -214,7 +221,7 @@ async def _prefetch(hit: models.ScoredPoint) -> None: prefetch_task = asyncio.gather(*[_prefetch(h) for h in hits[:limit]]) def _rerank(hits: list[models.ScoredPoint]) -> list[models.ScoredPoint]: - if _reranker is None: + if reranker is None: return hits docs: list[str] = [] for h in hits: @@ -228,7 +235,7 @@ def _rerank(hits: list[models.ScoredPoint]) -> list[models.ScoredPoint]: ] docs.append(" ".join(p for p in parts if p)) pairs = [(query, d) for d in docs] - scores = _reranker.predict(pairs) + scores = reranker.predict(pairs) for h, s in zip(hits, scores): h.score = float(s) hits.sort(key=lambda h: h.score, reverse=True) diff --git a/tests/test_server.py b/tests/test_server.py index c3ff03a..04360e6 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -147,8 +147,8 @@ def fake_import(name, *args, **kwargs): return orig_import(name, *args, **kwargs) monkeypatch.setattr(builtins, "__import__", fake_import) - server = importlib.reload(importlib.import_module("mcp_plex.server")) - assert server._reranker is None + module = importlib.reload(importlib.import_module("mcp_plex.server")) + assert module.server.reranker is None def test_reranker_init_failure(monkeypatch): @@ -161,8 +161,8 @@ def __init__(self, *args, **kwargs): st_module.CrossEncoder = Broken monkeypatch.setitem(sys.modules, "sentence_transformers", st_module) - server = importlib.reload(importlib.import_module("mcp_plex.server")) - assert server._reranker is None + module = importlib.reload(importlib.import_module("mcp_plex.server")) + assert module.server.reranker is None def test_rest_endpoints(monkeypatch):