diff --git a/AGENTS.md b/AGENTS.md index 9fb6540..a7895ec 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -13,6 +13,8 @@ to combine dense and sparse results before optional cross-encoder reranking. - Qdrant client initialization moved into `PlexServer` to centralize state and simplify testing. +- Media payload and artwork caching centralized in `MediaCache` attached to + `PlexServer` to streamline cache management and testing. ## User Queries The project should handle natural-language searches and recommendations such as: diff --git a/mcp_plex/cache.py b/mcp_plex/cache.py new file mode 100644 index 0000000..e1f6e76 --- /dev/null +++ b/mcp_plex/cache.py @@ -0,0 +1,52 @@ +"""In-memory LRU cache for media payload and artwork data.""" +from __future__ import annotations + +from collections import OrderedDict +from typing import Any + + +class MediaCache: + """LRU caches for media payload, posters, and backgrounds.""" + + def __init__(self, size: int = 128) -> None: + self.size = size + self._payload: OrderedDict[str, dict[str, Any]] = OrderedDict() + self._poster: OrderedDict[str, str] = OrderedDict() + self._background: OrderedDict[str, str] = OrderedDict() + + def _set(self, cache: OrderedDict, key: str, value: Any) -> None: + if key in cache: + cache.move_to_end(key) + cache[key] = value + while len(cache) > self.size: + cache.popitem(last=False) + + def _get(self, cache: OrderedDict, key: str) -> Any | None: + if key in cache: + cache.move_to_end(key) + return cache[key] + return None + + def get_payload(self, key: str) -> dict[str, Any] | None: + return self._get(self._payload, key) + + def set_payload(self, key: str, value: dict[str, Any]) -> None: + self._set(self._payload, key, value) + + def get_poster(self, key: str) -> str | None: + return self._get(self._poster, key) + + def set_poster(self, key: str, value: str) -> None: + self._set(self._poster, key, value) + + def get_background(self, key: str) -> str | None: + return self._get(self._background, key) + + def set_background(self, key: str, value: str) -> None: + self._set(self._background, key, value) + + def clear(self) -> None: + """Remove all cached entries.""" + self._payload.clear() + self._poster.clear() + self._background.clear() diff --git a/mcp_plex/server.py b/mcp_plex/server.py index d2cf832..605d847 100644 --- a/mcp_plex/server.py +++ b/mcp_plex/server.py @@ -6,7 +6,6 @@ import inspect import json import os -from collections import OrderedDict from typing import Annotated, Any, Callable from fastapi import FastAPI @@ -21,6 +20,8 @@ from starlette.requests import Request from starlette.responses import JSONResponse, PlainTextResponse, Response +from .cache import MediaCache + try: from sentence_transformers import CrossEncoder except Exception: @@ -51,6 +52,9 @@ _QDRANT_URL = ":memory:" +_CACHE_SIZE = 128 + + class PlexServer(FastMCP): """FastMCP server with an attached Qdrant client.""" @@ -65,6 +69,7 @@ def __init__(self) -> None: # noqa: D401 - short description inherited prefer_grpc=_QDRANT_PREFER_GRPC, https=_QDRANT_HTTPS, ) + self.cache = MediaCache(_CACHE_SIZE) _USE_RERANKER = os.getenv("USE_RERANKER", "1") == "1" _reranker = None @@ -77,27 +82,6 @@ def __init__(self) -> None: # noqa: D401 - short description inherited server = PlexServer() -_CACHE_SIZE = 128 -_payload_cache: OrderedDict[str, dict[str, Any]] = OrderedDict() -_poster_cache: OrderedDict[str, str] = OrderedDict() -_background_cache: OrderedDict[str, str] = OrderedDict() - - -def _cache_set(cache: OrderedDict, key: str, value: Any) -> None: - if key in cache: - cache.move_to_end(key) - cache[key] = value - while len(cache) > _CACHE_SIZE: - cache.popitem(last=False) - - -def _cache_get(cache: OrderedDict, key: str) -> Any | None: - if key in cache: - cache.move_to_end(key) - return cache[key] - return None - - async def _find_records(identifier: str, limit: int = 5) -> list[models.Record]: """Locate records matching an identifier or title.""" # First, try direct ID lookup @@ -137,7 +121,7 @@ async def _find_records(identifier: str, limit: int = 5) -> list[models.Record]: async def _get_media_data(identifier: str) -> dict[str, Any]: """Return the first matching media record's payload.""" - cached = _cache_get(_payload_cache, identifier) + cached = server.cache.get_payload(identifier) if cached is not None: return cached records = await _find_records(identifier, limit=1) @@ -146,13 +130,13 @@ async def _get_media_data(identifier: str) -> dict[str, Any]: data = records[0].payload["data"] rating_key = str(data.get("plex", {}).get("rating_key")) if rating_key: - _cache_set(_payload_cache, rating_key, data) + server.cache.set_payload(rating_key, data) thumb = data.get("plex", {}).get("thumb") if thumb: - _cache_set(_poster_cache, rating_key, thumb) + server.cache.set_poster(rating_key, thumb) art = data.get("plex", {}).get("art") if art: - _cache_set(_background_cache, rating_key, art) + server.cache.set_background(rating_key, art) return data @@ -219,13 +203,13 @@ async def _prefetch(hit: models.ScoredPoint) -> None: data = hit.payload["data"] rating_key = str(data.get("plex", {}).get("rating_key")) if rating_key: - _cache_set(_payload_cache, rating_key, data) + server.cache.set_payload(rating_key, data) thumb = data.get("plex", {}).get("thumb") if thumb: - _cache_set(_poster_cache, rating_key, thumb) + server.cache.set_poster(rating_key, thumb) art = data.get("plex", {}).get("art") if art: - _cache_set(_background_cache, rating_key, art) + server.cache.set_background(rating_key, art) prefetch_task = asyncio.gather(*[_prefetch(h) for h in hits[:limit]]) @@ -457,14 +441,16 @@ async def media_poster( ], ) -> str: """Return the poster image URL for the given media identifier.""" - cached = _cache_get(_poster_cache, identifier) + cached = server.cache.get_poster(identifier) if cached: return cached data = await _get_media_data(identifier) thumb = data.get("plex", {}).get("thumb") if not thumb: raise ValueError("Poster not available") - _cache_set(_poster_cache, str(data.get("plex", {}).get("rating_key")), thumb) + server.cache.set_poster( + str(data.get("plex", {}).get("rating_key")), thumb + ) return thumb @@ -479,14 +465,16 @@ async def media_background( ], ) -> str: """Return the background art URL for the given media identifier.""" - cached = _cache_get(_background_cache, identifier) + cached = server.cache.get_background(identifier) if cached: return cached data = await _get_media_data(identifier) art = data.get("plex", {}).get("art") if not art: raise ValueError("Background not available") - _cache_set(_background_cache, str(data.get("plex", {}).get("rating_key")), art) + server.cache.set_background( + str(data.get("plex", {}).get("rating_key")), art + ) return art diff --git a/pyproject.toml b/pyproject.toml index d976ac0..ae54d1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mcp-plex" -version = "0.26.10" +version = "0.26.11" description = "Plex-Oriented Model Context Protocol Server" requires-python = ">=3.11,<3.13" diff --git a/tests/test_server.py b/tests/test_server.py index ba518b1..c3ff03a 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -76,12 +76,17 @@ def test_server_tools(monkeypatch): poster = asyncio.run(server.media_poster.fn(identifier=movie_id)) assert isinstance(poster, str) and "thumb" in poster + assert server.server.cache.get_poster(movie_id) == poster art = asyncio.run(server.media_background.fn(identifier=movie_id)) assert isinstance(art, str) and "art" in art + assert server.server.cache.get_background(movie_id) == art item = json.loads(asyncio.run(server.media_item.fn(identifier=movie_id))) assert item["plex"]["rating_key"] == movie_id + assert ( + server.server.cache.get_payload(movie_id)["plex"]["rating_key"] == movie_id + ) ids = json.loads(asyncio.run(server.media_ids.fn(identifier=movie_id))) assert ids["imdb"] == "tt8367814" diff --git a/uv.lock b/uv.lock index 817dfa1..b7e6863 100644 --- a/uv.lock +++ b/uv.lock @@ -690,7 +690,7 @@ wheels = [ [[package]] name = "mcp-plex" -version = "0.26.10" +version = "0.26.11" source = { editable = "." } dependencies = [ { name = "fastapi" },