diff --git a/docker/pyproject.deps.toml b/docker/pyproject.deps.toml index cc1ae1e..2f30acb 100644 --- a/docker/pyproject.deps.toml +++ b/docker/pyproject.deps.toml @@ -1,6 +1,6 @@ [project] name = "mcp-plex" -version = "2.0.13" +version = "2.0.14" requires-python = ">=3.11,<3.13" dependencies = [ "fastmcp>=2.11.2", diff --git a/mcp_plex/server/__init__.py b/mcp_plex/server/__init__.py index 29381b8..461e033 100644 --- a/mcp_plex/server/__init__.py +++ b/mcp_plex/server/__init__.py @@ -103,6 +103,8 @@ def _lifespan(app: FastMCP) -> _ServerLifespan: # noqa: ARG001 super().__init__(lifespan=_lifespan) self._reranker: CrossEncoder | None = None self._reranker_loaded = False + self._reranker_lock = asyncio.Lock() + self._ensure_reranker_task: asyncio.Task[CrossEncoder | None] | None = None self.cache = MediaCache(self.settings.cache_size) self.client_identifier = uuid.uuid4().hex self._plex_identity: dict[str, Any] | None = None @@ -145,16 +147,39 @@ def settings(self) -> Settings: # type: ignore[override] def reranker(self) -> CrossEncoder | None: if not self.settings.use_reranker or CrossEncoder is None: return None - if not self._reranker_loaded: + if self._reranker_loaded: + return self._reranker + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(self.ensure_reranker()) + else: + if self._ensure_reranker_task is None or self._ensure_reranker_task.done(): + self._ensure_reranker_task = loop.create_task(self.ensure_reranker()) + return self._reranker + + async def ensure_reranker(self) -> CrossEncoder | None: + if not self.settings.use_reranker or CrossEncoder is None: + self._reranker_loaded = True + self._reranker = None + return None + if self._reranker_loaded: + return self._reranker + async with self._reranker_lock: + if self._reranker_loaded: + return self._reranker try: - self._reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2") + reranker = await asyncio.to_thread( + CrossEncoder, "cross-encoder/ms-marco-MiniLM-L-6-v2" + ) except Exception as exc: logger.warning( "Failed to initialize CrossEncoder reranker: %s", exc, exc_info=exc, ) - self._reranker = None + reranker = None + self._reranker = reranker self._reranker_loaded = True return self._reranker diff --git a/mcp_plex/server/tools/media_library.py b/mcp_plex/server/tools/media_library.py index a09653b..ddabe71 100644 --- a/mcp_plex/server/tools/media_library.py +++ b/mcp_plex/server/tools/media_library.py @@ -87,7 +87,7 @@ async def search_media( dense_doc = models.Document(text=query, model=server.settings.dense_model) sparse_doc = models.Document(text=query, model=server.settings.sparse_model) - reranker = server.reranker + reranker = await server.ensure_reranker() candidate_limit = limit * 3 if reranker is not None else limit prefetch = [ models.Prefetch( diff --git a/pyproject.toml b/pyproject.toml index 67fc9ed..5060126 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mcp-plex" -version = "2.0.13" +version = "2.0.14" description = "Plex-Oriented Model Context Protocol Server" requires-python = ">=3.11,<3.13" diff --git a/tests/test_server.py b/tests/test_server.py index c948223..7a3f2cd 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -23,6 +23,20 @@ from pydantic import ValidationError +def _reload_server_with_dummy_reranker(monkeypatch): + monkeypatch.setenv("USE_RERANKER", "1") + st_module = types.ModuleType("sentence_transformers") + + class Dummy: + def __init__(self, model_id: str) -> None: + self.model_id = model_id + + st_module.CrossEncoder = Dummy + monkeypatch.setitem(sys.modules, "sentence_transformers", st_module) + module = importlib.reload(importlib.import_module("mcp_plex.server")) + return module, Dummy + + @contextmanager def _load_server(monkeypatch): from qdrant_client import async_qdrant_client @@ -526,6 +540,53 @@ def __init__(self, *args, **kwargs): asyncio.run(module.server.close()) +def test_ensure_reranker_uses_thread_executor(monkeypatch): + module, Dummy = _reload_server_with_dummy_reranker(monkeypatch) + calls: list[tuple[object, tuple[object, ...], dict[str, object]]] = [] + + async def fake_to_thread(fn, *args, **kwargs): # type: ignore[no-untyped-def] + calls.append((fn, args, kwargs)) + return fn(*args, **kwargs) + + monkeypatch.setattr(asyncio, "to_thread", fake_to_thread) + + async def exercise(): + reranker = await module.server.ensure_reranker() + assert isinstance(reranker, Dummy) + assert reranker is await module.server.ensure_reranker() + + asyncio.run(exercise()) + + assert len(calls) == 1 + fn, args, _ = calls[0] + assert fn is Dummy + assert args == ("cross-encoder/ms-marco-MiniLM-L-6-v2",) + asyncio.run(module.server.close()) + + +def test_ensure_reranker_concurrent_calls_share_single_instance(monkeypatch): + module, Dummy = _reload_server_with_dummy_reranker(monkeypatch) + call_count = 0 + + async def fake_to_thread(fn, *args, **kwargs): # type: ignore[no-untyped-def] + nonlocal call_count + call_count += 1 + await asyncio.sleep(0) + return fn(*args, **kwargs) + + monkeypatch.setattr(asyncio, "to_thread", fake_to_thread) + + async def exercise(): + results = await asyncio.gather( + *(module.server.ensure_reranker() for _ in range(5)) + ) + assert call_count == 1 + assert all(reranker is results[0] for reranker in results) + + asyncio.run(exercise()) + asyncio.run(module.server.close()) + + def test_rest_endpoints(monkeypatch): with _load_server(monkeypatch) as module: client = TestClient(module.server.http_app()) diff --git a/uv.lock b/uv.lock index 800152a..fb44430 100644 --- a/uv.lock +++ b/uv.lock @@ -730,7 +730,7 @@ wheels = [ [[package]] name = "mcp-plex" -version = "2.0.13" +version = "2.0.14" source = { editable = "." } dependencies = [ { name = "fastapi" },