Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/pyproject.deps.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "mcp-plex"
version = "2.0.13"
version = "2.0.14"
requires-python = ">=3.11,<3.13"
dependencies = [
"fastmcp>=2.11.2",
Expand Down
31 changes: 28 additions & 3 deletions mcp_plex/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def _lifespan(app: FastMCP) -> _ServerLifespan: # noqa: ARG001
super().__init__(lifespan=_lifespan)
self._reranker: CrossEncoder | None = None
self._reranker_loaded = False
self._reranker_lock = asyncio.Lock()
self._ensure_reranker_task: asyncio.Task[CrossEncoder | None] | None = None
self.cache = MediaCache(self.settings.cache_size)
self.client_identifier = uuid.uuid4().hex
self._plex_identity: dict[str, Any] | None = None
Expand Down Expand Up @@ -145,16 +147,39 @@ def settings(self) -> Settings: # type: ignore[override]
def reranker(self) -> CrossEncoder | None:
    """Return the cross-encoder reranker without awaiting.

    Returns ``None`` when reranking is disabled in settings or the
    ``CrossEncoder`` class is unavailable. Once a load attempt has
    finished, the cached result (possibly ``None``) is returned.

    When no load has completed yet the behaviour depends on the calling
    context:

    * No running event loop: the load is driven to completion with
      ``asyncio.run`` and its result is returned.
    * Inside a running event loop: a background load task is scheduled
      (at most one at a time) and the current cached value is returned,
      which is ``None`` until that task finishes. Async callers that need
      the loaded model should ``await self.ensure_reranker()`` instead.
    """
    if not self.settings.use_reranker or CrossEncoder is None:
        return None
    if self._reranker_loaded:
        return self._reranker
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so blocking on the loader is
        # safe here.
        return asyncio.run(self.ensure_reranker())
    # A loop is already running, so asyncio.run() cannot be used. Kick off
    # the load in the background; keeping the task on the instance holds a
    # strong reference to it and prevents scheduling duplicates.
    pending = self._ensure_reranker_task
    if pending is None or pending.done():
        self._ensure_reranker_task = loop.create_task(self.ensure_reranker())
    return self._reranker

async def ensure_reranker(self) -> CrossEncoder | None:
    """Load the cross-encoder reranker once and return it.

    The model is constructed via ``asyncio.to_thread`` so the event loop is
    not blocked while it loads. Concurrent callers are serialized on
    ``self._reranker_lock``; only the first one performs the load and the
    rest reuse its result.

    Returns:
        The loaded ``CrossEncoder`` instance, or ``None`` when reranking is
        disabled, ``CrossEncoder`` is unavailable, or construction failed.
        A failure is logged as a warning and is not retried, because the
        loaded flag is set either way.
    """
    if not self.settings.use_reranker or CrossEncoder is None:
        self._reranker_loaded = True
        self._reranker = None
        return None
    # Fast path: skip the lock once a load attempt has completed.
    if self._reranker_loaded:
        return self._reranker
    async with self._reranker_lock:
        # Re-check under the lock: another caller may have finished the
        # load while this one was waiting to acquire it.
        if self._reranker_loaded:
            return self._reranker
        try:
            reranker = await asyncio.to_thread(
                CrossEncoder, "cross-encoder/ms-marco-MiniLM-L-6-v2"
            )
        except Exception as exc:
            logger.warning(
                "Failed to initialize CrossEncoder reranker: %s",
                exc,
                exc_info=exc,
            )
            reranker = None
        # Publish the result before releasing the lock so waiting callers
        # observe the loaded flag and never start a second load.
        self._reranker = reranker
        self._reranker_loaded = True
    return self._reranker

Expand Down
2 changes: 1 addition & 1 deletion mcp_plex/server/tools/media_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async def search_media(

dense_doc = models.Document(text=query, model=server.settings.dense_model)
sparse_doc = models.Document(text=query, model=server.settings.sparse_model)
reranker = server.reranker
reranker = await server.ensure_reranker()
candidate_limit = limit * 3 if reranker is not None else limit
prefetch = [
models.Prefetch(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "mcp-plex"
version = "2.0.13"
version = "2.0.14"

description = "Plex-Oriented Model Context Protocol Server"
requires-python = ">=3.11,<3.13"
Expand Down
61 changes: 61 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@
from pydantic import ValidationError


def _reload_server_with_dummy_reranker(monkeypatch):
    """Reload ``mcp_plex.server`` with a stub ``sentence_transformers``.

    Sets ``USE_RERANKER=1`` and installs a fake ``sentence_transformers``
    module whose ``CrossEncoder`` only records the model id it was given.

    Returns:
        A ``(module, Dummy)`` tuple: the reloaded server module and the stub
        class that stands in for ``CrossEncoder``.
    """

    class Dummy:
        def __init__(self, model_id: str) -> None:
            self.model_id = model_id

    monkeypatch.setenv("USE_RERANKER", "1")

    # Register the stub before reloading so the server module's import of
    # sentence_transformers resolves to it.
    fake_package = types.ModuleType("sentence_transformers")
    fake_package.CrossEncoder = Dummy
    monkeypatch.setitem(sys.modules, "sentence_transformers", fake_package)

    server_module = importlib.import_module("mcp_plex.server")
    reloaded = importlib.reload(server_module)
    return reloaded, Dummy


@contextmanager
def _load_server(monkeypatch):
from qdrant_client import async_qdrant_client
Expand Down Expand Up @@ -526,6 +540,53 @@ def __init__(self, *args, **kwargs):
asyncio.run(module.server.close())


def test_ensure_reranker_uses_thread_executor(monkeypatch):
    """``ensure_reranker`` builds the model through ``asyncio.to_thread`` once."""
    module, Dummy = _reload_server_with_dummy_reranker(monkeypatch)
    recorded: list[tuple[object, tuple[object, ...], dict[str, object]]] = []

    async def recording_to_thread(func, *args, **kwargs):  # type: ignore[no-untyped-def]
        # Run the callable inline, but remember how it was invoked.
        recorded.append((func, args, kwargs))
        return func(*args, **kwargs)

    monkeypatch.setattr(asyncio, "to_thread", recording_to_thread)

    async def scenario():
        first = await module.server.ensure_reranker()
        assert isinstance(first, Dummy)
        # The second call must return the cached instance, not a new one.
        second = await module.server.ensure_reranker()
        assert first is second

    asyncio.run(scenario())

    assert len(recorded) == 1
    target, positional, _ = recorded[0]
    assert target is Dummy
    assert positional == ("cross-encoder/ms-marco-MiniLM-L-6-v2",)
    asyncio.run(module.server.close())


def test_ensure_reranker_concurrent_calls_share_single_instance(monkeypatch):
    """Concurrent ``ensure_reranker`` calls construct the model only once."""
    module, _dummy_cls = _reload_server_with_dummy_reranker(monkeypatch)
    invocations = 0

    async def yielding_to_thread(func, *args, **kwargs):  # type: ignore[no-untyped-def]
        nonlocal invocations
        invocations += 1
        # Yield to the event loop so the other pending callers get to run
        # while this load is still in flight.
        await asyncio.sleep(0)
        return func(*args, **kwargs)

    monkeypatch.setattr(asyncio, "to_thread", yielding_to_thread)

    async def scenario():
        pending = [module.server.ensure_reranker() for _ in range(5)]
        loaded = await asyncio.gather(*pending)
        assert invocations == 1
        first = loaded[0]
        assert all(item is first for item in loaded)

    asyncio.run(scenario())
    asyncio.run(module.server.close())


def test_rest_endpoints(monkeypatch):
with _load_server(monkeypatch) as module:
client = TestClient(module.server.http_app())
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.