Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
to combine dense and sparse results before optional cross-encoder reranking.
- Qdrant client initialization moved into `PlexServer` to centralize state and
simplify testing.
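- Cross-encoder reranker is initialized lazily, on first access to a `PlexServer` property, to
avoid unnecessary model downloads when reranking is disabled or unavailable; the result
(including a failed load) is remembered so initialization is attempted only once.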
- Media payload and artwork caching centralized in `MediaCache` attached to
`PlexServer` to streamline cache management and testing.

Expand Down
35 changes: 21 additions & 14 deletions mcp_plex/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@

from .cache import MediaCache

try:
from sentence_transformers import CrossEncoder
except Exception:
CrossEncoder = None

try:
from sentence_transformers import CrossEncoder
except Exception:
Expand Down Expand Up @@ -69,15 +64,26 @@ def __init__(self) -> None: # noqa: D401 - short description inherited
prefer_grpc=_QDRANT_PREFER_GRPC,
https=_QDRANT_HTTPS,
)
self._reranker: CrossEncoder | None = None
self._reranker_loaded = False
self.cache = MediaCache(_CACHE_SIZE)

@property
def reranker(self) -> CrossEncoder | None:
    """Return the cross-encoder reranker, building it on first access.

    Returns ``None`` without attempting a load when ``_USE_RERANKER`` is
    falsy or ``CrossEncoder`` is ``None``. Otherwise the model is
    constructed once; the outcome, including ``None`` when construction
    raises, is stored on the instance and returned on later accesses.
    """
    reranking_available = _USE_RERANKER and CrossEncoder is not None
    if not reranking_available:
        return None

    if self._reranker_loaded:
        return self._reranker

    try:
        model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    except Exception:
        # Best effort: a failed construction leaves reranking disabled.
        model = None

    self._reranker = model
    # Set even after a failure so construction is not retried on each access.
    self._reranker_loaded = True
    return self._reranker


_USE_RERANKER = os.getenv("USE_RERANKER", "1") == "1"
_reranker = None
if _USE_RERANKER and CrossEncoder is not None:
try:
_reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
except Exception:
_reranker = None

server = PlexServer()

Expand Down Expand Up @@ -177,7 +183,8 @@ async def search_media(
"""Hybrid similarity search across media items using dense and sparse vectors."""
dense_doc = models.Document(text=query, model=_DENSE_MODEL_NAME)
sparse_doc = models.Document(text=query, model=_SPARSE_MODEL_NAME)
candidate_limit = limit * 3 if _reranker is not None else limit
reranker = server.reranker
candidate_limit = limit * 3 if reranker is not None else limit
prefetch = [
models.Prefetch(
query=models.NearestQuery(nearest=dense_doc),
Expand Down Expand Up @@ -214,7 +221,7 @@ async def _prefetch(hit: models.ScoredPoint) -> None:
prefetch_task = asyncio.gather(*[_prefetch(h) for h in hits[:limit]])

def _rerank(hits: list[models.ScoredPoint]) -> list[models.ScoredPoint]:
if _reranker is None:
if reranker is None:
return hits
docs: list[str] = []
for h in hits:
Expand All @@ -228,7 +235,7 @@ def _rerank(hits: list[models.ScoredPoint]) -> list[models.ScoredPoint]:
]
docs.append(" ".join(p for p in parts if p))
pairs = [(query, d) for d in docs]
scores = _reranker.predict(pairs)
scores = reranker.predict(pairs)
for h, s in zip(hits, scores):
h.score = float(s)
hits.sort(key=lambda h: h.score, reverse=True)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def fake_import(name, *args, **kwargs):
return orig_import(name, *args, **kwargs)

monkeypatch.setattr(builtins, "__import__", fake_import)
server = importlib.reload(importlib.import_module("mcp_plex.server"))
assert server._reranker is None
module = importlib.reload(importlib.import_module("mcp_plex.server"))
assert module.server.reranker is None


def test_reranker_init_failure(monkeypatch):
Expand All @@ -161,8 +161,8 @@ def __init__(self, *args, **kwargs):

st_module.CrossEncoder = Broken
monkeypatch.setitem(sys.modules, "sentence_transformers", st_module)
server = importlib.reload(importlib.import_module("mcp_plex.server"))
assert server._reranker is None
module = importlib.reload(importlib.import_module("mcp_plex.server"))
assert module.server.reranker is None


def test_rest_endpoints(monkeypatch):
Expand Down