From 9f60ed5bd73dbdb8c59cb06ee2ba77f359bb2810 Mon Sep 17 00:00:00 2001 From: Teagan Glenn Date: Sun, 31 Aug 2025 00:58:17 -0600 Subject: [PATCH 1/3] fix(server): unwrap sparse query embedding --- AGENTS.md | 1 + mcp_plex/server.py | 38 +++++++++++- tests/test_server.py | 134 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 165 insertions(+), 8 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 82bd25a..b20cb21 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -14,6 +14,7 @@ ## Testing Practices - Use realistic (or as realistic as possible) data in tests; avoid meaningless placeholder values. - Always test both positive and negative logical paths. +- Do **not** use `# pragma: no cover`; add tests to exercise code paths instead. ## Efficiency and Search - Use `rg` (ripgrep) for recursive search. diff --git a/mcp_plex/server.py b/mcp_plex/server.py index 9a0da73..7767c2a 100644 --- a/mcp_plex/server.py +++ b/mcp_plex/server.py @@ -13,6 +13,11 @@ from fastembed import TextEmbedding, SparseTextEmbedding from pydantic import Field +try: + from sentence_transformers import CrossEncoder +except Exception: + CrossEncoder = None + # Environment configuration for Qdrant _QDRANT_URL = os.getenv("QDRANT_URL", ":memory:") _QDRANT_API_KEY = os.getenv("QDRANT_API_KEY") @@ -22,6 +27,14 @@ _dense_model = TextEmbedding("BAAI/bge-small-en-v1.5") _sparse_model = SparseTextEmbedding("Qdrant/bm42-all-minilm-l6-v2-attentions") +_USE_RERANKER = os.getenv("USE_RERANKER", "1") == "1" +_reranker = None +if _USE_RERANKER and CrossEncoder is not None: + try: + _reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2") + except Exception: + _reranker = None + server = FastMCP() @@ -138,18 +151,19 @@ async def search_media( ) -> list[dict[str, Any]]: """Hybrid similarity search across media items using dense and sparse vectors.""" dense_task = asyncio.to_thread(lambda: list(_dense_model.embed([query]))[0]) - sparse_task = asyncio.to_thread(lambda: _sparse_model.query_embed(query)) + sparse_task = asyncio.to_thread(lambda: next(_sparse_model.query_embed(query))) dense_vec, sparse_vec = await asyncio.gather(dense_task, sparse_task) named_dense = models.NamedVector(name="dense", vector=dense_vec) sv = models.SparseVector( indices=sparse_vec.indices.tolist(), values=sparse_vec.values.tolist() ) named_sparse = models.NamedSparseVector(name="sparse", vector=sv) + candidate_limit = limit * 3 if _reranker is not None else limit hits = await _client.search( collection_name="media-items", query_vector=named_dense, query_sparse_vector=named_sparse, - limit=limit * 3, + limit=candidate_limit, with_payload=True, ) @@ -168,6 +182,24 @@ async def _prefetch(hit: models.ScoredPoint) -> None: prefetch_task = asyncio.gather(*[_prefetch(h) for h in hits[:limit]]) def _rerank(hits: list[models.ScoredPoint]) -> list[models.ScoredPoint]: + if _reranker is None: + return hits + docs: list[str] = [] + for h in hits: + data = h.payload["data"] + parts = [ + data.get("title"), + data.get("summary"), + data.get("plex", {}).get("title"), + data.get("plex", {}).get("summary"), + data.get("tmdb", {}).get("overview"), + ] + docs.append(" ".join(p for p in parts if p)) + pairs = [(query, d) for d in docs] + scores = _reranker.predict(pairs) + for h, s in zip(hits, scores): + h.score = float(s) + hits.sort(key=lambda h: h.score, reverse=True) return hits reranked = await asyncio.to_thread(_rerank, hits) @@ -290,5 +322,5 @@ async def media_background( return art -if __name__ == "__main__": # pragma: no cover +if __name__ == "__main__": server.run() diff --git a/tests/test_server.py b/tests/test_server.py index 6594af0..c36ee17 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -5,6 +5,7 @@ import types import json import time +import sys import pytest from mcp_plex import loader @@ -50,7 +51,25 @@ def passage_embed(self, texts): def query_embed(self, text): time.sleep(0.1) - return DummySparseVector([0], [1.0]) + yield DummySparseVector([0], [1.0]) + + +class DummyReranker: + def __init__(self, model: str): + pass + + def predict(self, pairs): + scores = [] + for _, doc in pairs: + if "Gentlemen" in doc: + scores.append(10) + elif "C" in doc: + scores.append(2) + elif "B" in doc: + scores.append(1) + else: + scores.append(0) + return scores class DummyQdrantClient: @@ -100,7 +119,13 @@ def matches(rec, cond): return records[:limit], None async def search(self, collection_name: str, query_vector, query_sparse_vector=None, limit: int = 5, with_payload=False, **kwargs): - return list(self.store.values())[:limit] + records = list(self.store.values())[:limit] + return [ + models.ScoredPoint( + id=r.id, version=1, score=0.0, payload=r.payload, vector=None + ) + for r in records + ] async def recommend(self, collection_name: str, positive, limit: int = 5, with_payload=False, **kwargs): return [r for r in self.store.values() if r.id not in positive][:limit] @@ -128,6 +153,10 @@ def test_server_tools(tmp_path, monkeypatch): monkeypatch.setattr(fastembed, "TextEmbedding", DummyTextEmbedding) monkeypatch.setattr(fastembed, "SparseTextEmbedding", DummySparseEmbedding) monkeypatch.setattr(async_qdrant_client, "AsyncQdrantClient", DummyQdrantClient) + monkeypatch.setenv("USE_RERANKER", "1") + st_module = types.ModuleType("sentence_transformers") + st_module.CrossEncoder = DummyReranker + monkeypatch.setitem(sys.modules, "sentence_transformers", st_module) asyncio.run(_setup_db(tmp_path)) server = importlib.reload(importlib.import_module("mcp_plex.server")) @@ -142,6 +171,13 @@ def test_server_tools(tmp_path, monkeypatch): res = asyncio.run(server.get_media.fn(identifier="The Gentlemen")) assert res and res[0]["plex"]["rating_key"] == movie_id + poster = asyncio.run(server.media_poster.fn(identifier=movie_id)) + assert isinstance(poster, str) and "thumb" in poster + art = asyncio.run(server.media_background.fn(identifier=movie_id)) + assert isinstance(art, str) and "art" in art + item = json.loads(asyncio.run(server.media_item.fn(identifier=movie_id))) + assert item["plex"]["rating_key"] == movie_id + start = time.perf_counter() res = asyncio.run( server.search_media.fn(query="Matthew McConaughey crime movie", limit=1) @@ -150,12 +186,16 @@ def test_server_tools(tmp_path, monkeypatch): assert elapsed < 0.2 assert res and res[0]["plex"]["title"] == "The Gentlemen" - # Prefetched payloads should allow resource access without hitting the client + # _find_records should handle client retrieval errors gracefully orig_retrieve, orig_scroll = server._client.retrieve, server._client.scroll - - async def fail(*args, **kwargs): # pragma: no cover + async def fail(*args, **kwargs): raise AssertionError("client called") + server._client.retrieve = fail + asyncio.run(server._find_records("12345", limit=1)) + server._client.retrieve = orig_retrieve + + # Prefetched payloads should allow resource access without hitting the client server._client.retrieve = fail server._client.scroll = fail try: @@ -174,6 +214,28 @@ async def fail(*args, **kwargs): # pragma: no cover server._client.retrieve = orig_retrieve server._client.scroll = orig_scroll + with pytest.raises(AssertionError): + asyncio.run(fail()) + + monkeypatch.setattr(server, "_CACHE_SIZE", 1) + server._cache_set(server._poster_cache, "a", "1") + server._cache_set(server._poster_cache, "b", "2") + + # Reranking should reorder results based on cross-encoder scores + orig_search = server._client.search + async def fake_search(*args, **kwargs): + return [ + models.ScoredPoint(id=1, version=1, score=0.0, payload={"data": {"title": "A"}}, vector=None), + models.ScoredPoint(id=2, version=1, score=0.0, payload={"data": {"title": "B"}}, vector=None), + models.ScoredPoint(id=3, version=1, score=0.0, payload={"data": {"title": "C"}}, vector=None), + ] + server._client.search = fake_search + try: + res = asyncio.run(server.search_media.fn(query="test", limit=2)) + assert [i["title"] for i in res] == ["C", "B"] + finally: + server._client.search = orig_search + res = asyncio.run(server.recommend_media.fn(identifier=movie_id, limit=1)) assert res and res[0]["plex"]["rating_key"] == "61960" @@ -194,3 +256,65 @@ async def fail(*args, **kwargs): # pragma: no cover with pytest.raises(ValueError): asyncio.run(server.media_background.fn(identifier="0")) + + +def _patch_dependencies(monkeypatch): + monkeypatch.setattr(loader, "TextEmbedding", DummyTextEmbedding) + monkeypatch.setattr(loader, "SparseTextEmbedding", DummySparseEmbedding) + monkeypatch.setattr(loader, "AsyncQdrantClient", DummyQdrantClient) + import fastembed + from qdrant_client import async_qdrant_client + monkeypatch.setattr(fastembed, "TextEmbedding", DummyTextEmbedding) + monkeypatch.setattr(fastembed, "SparseTextEmbedding", DummySparseEmbedding) + monkeypatch.setattr(async_qdrant_client, "AsyncQdrantClient", DummyQdrantClient) + + +def test_reranker_import_failure(monkeypatch): + _patch_dependencies(monkeypatch) + monkeypatch.delitem(sys.modules, "sentence_transformers", raising=False) + import builtins + orig_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "sentence_transformers": + raise ImportError + return orig_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + monkeypatch.setenv("USE_RERANKER", "1") + server = importlib.reload(importlib.import_module("mcp_plex.server")) + + async def fake_search(*args, **kwargs): + return [ + models.ScoredPoint(id=1, version=1, score=0.0, payload={"data": {"title": "A", "plex": {"rating_key": 1}}}, vector=None), + models.ScoredPoint(id=2, version=1, score=0.0, payload={"data": {"title": "B", "plex": {"rating_key": 2}}}, vector=None), + ] + + server._client.search = fake_search + res = asyncio.run(server.search_media.fn(query="test", limit=2)) + assert [i["title"] for i in res] == ["A", "B"] + + +def test_reranker_init_failure(monkeypatch): + _patch_dependencies(monkeypatch) + monkeypatch.setenv("USE_RERANKER", "1") + st_module = types.ModuleType("sentence_transformers") + + class Broken: + def __init__(self, *args, **kwargs): + raise RuntimeError("boom") + + st_module.CrossEncoder = Broken + monkeypatch.setitem(sys.modules, "sentence_transformers", st_module) + server = importlib.reload(importlib.import_module("mcp_plex.server")) + assert server._reranker is None + + async def fake_search(*args, **kwargs): + return [ + models.ScoredPoint(id=1, version=1, score=0.0, payload={"data": {"title": "A", "plex": {"rating_key": 1}}}, vector=None), + models.ScoredPoint(id=2, version=1, score=0.0, payload={"data": {"title": "B", "plex": {"rating_key": 2}}}, vector=None), + ] + + server._client.search = fake_search + res = asyncio.run(server.search_media.fn(query="test", limit=2)) + assert [i["title"] for i in res] == ["A", "B"] From e9da33ccbf8e589d932829336310bde0ff6d1cc5 Mon Sep 17 00:00:00 2001 From: Teagan Glenn Date: Sun, 31 Aug 2025 01:04:13 -0600 Subject: [PATCH 2/3] style: sort imports for ruff --- mcp_plex/loader.py | 8 ++++---- mcp_plex/server.py | 10 +++++----- mcp_plex/types.py | 2 +- tests/test_load_from_plex.py | 1 + tests/test_loader_cli.py | 3 ++- tests/test_loader_integration.py | 3 ++- tests/test_loader_unit.py | 12 ++++++------ tests/test_server.py | 11 ++++++----- 8 files changed, 27 insertions(+), 23 deletions(-) diff --git a/mcp_plex/loader.py b/mcp_plex/loader.py index 3d52d7b..6a7275b 100644 --- a/mcp_plex/loader.py +++ b/mcp_plex/loader.py @@ -10,25 +10,25 @@ import click import httpx from fastembed import SparseTextEmbedding, TextEmbedding -from qdrant_client.async_qdrant_client import AsyncQdrantClient from qdrant_client import models +from qdrant_client.async_qdrant_client import AsyncQdrantClient from .types import ( AggregatedItem, ExternalIDs, IMDbTitle, + PlexGuid, PlexItem, + PlexPerson, TMDBEpisode, TMDBItem, TMDBMovie, TMDBShow, - PlexGuid, - PlexPerson, ) try: # Only import plexapi when available; the sample data mode does not require it. - from plexapi.server import PlexServer from plexapi.base import PlexPartialObject + from plexapi.server import PlexServer except Exception: PlexServer = None # type: ignore[assignment] PlexPartialObject = object # type: ignore[assignment] diff --git a/mcp_plex/server.py b/mcp_plex/server.py index 7767c2a..642e2fd 100644 --- a/mcp_plex/server.py +++ b/mcp_plex/server.py @@ -2,16 +2,16 @@ from __future__ import annotations import asyncio -import os import json +import os from collections import OrderedDict -from typing import Any, Annotated +from typing import Annotated, Any +from fastembed import SparseTextEmbedding, TextEmbedding from fastmcp.server import FastMCP -from qdrant_client.async_qdrant_client import AsyncQdrantClient -from qdrant_client import models -from fastembed import TextEmbedding, SparseTextEmbedding from pydantic import Field +from qdrant_client import models +from qdrant_client.async_qdrant_client import AsyncQdrantClient try: from sentence_transformers import CrossEncoder diff --git a/mcp_plex/types.py b/mcp_plex/types.py index 28040f5..532f05d 100644 --- a/mcp_plex/types.py +++ b/mcp_plex/types.py @@ -2,7 +2,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import List, Optional, Literal +from typing import List, Literal, Optional from pydantic import BaseModel, Field diff --git a/tests/test_load_from_plex.py b/tests/test_load_from_plex.py index 44c4c81..d7b757b 100644 --- a/tests/test_load_from_plex.py +++ b/tests/test_load_from_plex.py @@ -1,5 +1,6 @@ import asyncio import types + import httpx from mcp_plex import loader diff --git a/tests/test_loader_cli.py b/tests/test_loader_cli.py index 80d6249..da46024 100644 --- a/tests/test_loader_cli.py +++ b/tests/test_loader_cli.py @@ -1,6 +1,7 @@ import asyncio -from click.testing import CliRunner + import pytest +from click.testing import CliRunner from mcp_plex import loader diff --git a/tests/test_loader_integration.py b/tests/test_loader_integration.py index d75b34c..c2b97ee 100644 --- a/tests/test_loader_integration.py +++ b/tests/test_loader_integration.py @@ -2,9 +2,10 @@ from pathlib import Path from types import SimpleNamespace -from mcp_plex import loader from qdrant_client import models +from mcp_plex import loader + class DummyTextEmbedding: def __init__(self, name: str): diff --git a/tests/test_loader_unit.py b/tests/test_loader_unit.py index eeb3a03..c6fe106 100644 --- a/tests/test_loader_unit.py +++ b/tests/test_loader_unit.py @@ -1,17 +1,17 @@ -import types import asyncio -import httpx - +import types from pathlib import Path +import httpx + from mcp_plex.loader import ( - _extract_external_ids, - _load_from_sample, _build_plex_item, + _extract_external_ids, _fetch_imdb, + _fetch_tmdb_episode, _fetch_tmdb_movie, _fetch_tmdb_show, - _fetch_tmdb_episode, + _load_from_sample, ) diff --git a/tests/test_server.py b/tests/test_server.py index c36ee17..f93e717 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,15 +1,16 @@ -from typing import Any import asyncio -from pathlib import Path import importlib -import types import json -import time import sys +import time +import types +from pathlib import Path +from typing import Any + import pytest +from qdrant_client import models from mcp_plex import loader -from qdrant_client import models class DummyTextEmbedding: From d49fcd48815ed84a2f8d1fab3a455056c169c403 Mon Sep 17 00:00:00 2001 From: Teagan Glenn Date: Sun, 31 Aug 2025 01:06:40 -0600 Subject: [PATCH 3/3] Fix linting --- tests/test_server.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/tests/test_server.py b/tests/test_server.py index 0429596..f93e717 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -7,7 +7,6 @@ from pathlib import Path from typing import Any -import sys import pytest from qdrant_client import models @@ -74,24 +73,6 @@ def predict(self, pairs): return scores -class DummyReranker: - def __init__(self, model: str): - pass - - def predict(self, pairs): - scores = [] - for _, doc in pairs: - if "Gentlemen" in doc: - scores.append(10) - elif "C" in doc: - scores.append(2) - elif "B" in doc: - scores.append(1) - else: - scores.append(0) - return scores - - class DummyQdrantClient: store: dict[Any, models.Record] = {} size: int = 3