From d131fb565f28e00d7c32586c28a6b4c1dabe9a8d Mon Sep 17 00:00:00 2001 From: Teagan Glenn Date: Sun, 31 Aug 2025 00:05:22 -0600 Subject: [PATCH] feat(server): prefetch and cache media resources --- mcp_plex/server.py | 87 +++++++++++++++++++++++++++++++++++++------- tests/test_server.py | 42 +++++++++++++++------ 2 files changed, 104 insertions(+), 25 deletions(-) diff --git a/mcp_plex/server.py b/mcp_plex/server.py index 038ef2e..9a0da73 100644 --- a/mcp_plex/server.py +++ b/mcp_plex/server.py @@ -1,8 +1,10 @@ """FastMCP server exposing Plex metadata tools.""" from __future__ import annotations +import asyncio import os import json +from collections import OrderedDict from typing import Any, Annotated from fastmcp.server import FastMCP @@ -23,6 +25,27 @@ server = FastMCP() +_CACHE_SIZE = 128 +_payload_cache: OrderedDict[str, dict[str, Any]] = OrderedDict() +_poster_cache: OrderedDict[str, str] = OrderedDict() +_background_cache: OrderedDict[str, str] = OrderedDict() + + +def _cache_set(cache: OrderedDict, key: str, value: Any) -> None: + if key in cache: + cache.move_to_end(key) + cache[key] = value + while len(cache) > _CACHE_SIZE: + cache.popitem(last=False) + + +def _cache_get(cache: OrderedDict, key: str) -> Any | None: + if key in cache: + cache.move_to_end(key) + return cache[key] + return None + + async def _find_records(identifier: str, limit: int = 5) -> list[models.Record]: """Locate records matching an identifier or title.""" # First, try direct ID lookup @@ -60,10 +83,23 @@ async def _find_records(identifier: str, limit: int = 5) -> list[models.Record]: async def _get_media_data(identifier: str) -> dict[str, Any]: """Return the first matching media record's payload.""" + cached = _cache_get(_payload_cache, identifier) + if cached is not None: + return cached records = await _find_records(identifier, limit=1) if not records: raise ValueError("Media item not found") - return records[0].payload["data"] + data = records[0].payload["data"] + rating_key = str(data.get("plex", {}).get("rating_key")) + if rating_key: + _cache_set(_payload_cache, rating_key, data) + thumb = data.get("plex", {}).get("thumb") + if thumb: + _cache_set(_poster_cache, rating_key, thumb) + art = data.get("plex", {}).get("art") + if art: + _cache_set(_background_cache, rating_key, art) + return data @server.tool("get-media") @@ -101,8 +137,9 @@ async def search_media( ] = 5, ) -> list[dict[str, Any]]: """Hybrid similarity search across media items using dense and sparse vectors.""" - dense_vec = list(_dense_model.embed([query]))[0] - sparse_vec = _sparse_model.query_embed(query) + dense_task = asyncio.to_thread(lambda: list(_dense_model.embed([query]))[0]) + sparse_task = asyncio.to_thread(lambda: _sparse_model.query_embed(query)) + dense_vec, sparse_vec = await asyncio.gather(dense_task, sparse_task) named_dense = models.NamedVector(name="dense", vector=dense_vec) sv = models.SparseVector( indices=sparse_vec.indices.tolist(), values=sparse_vec.values.tolist() @@ -112,10 +149,30 @@ async def search_media( collection_name="media-items", query_vector=named_dense, query_sparse_vector=named_sparse, - limit=limit, + limit=limit * 3, with_payload=True, ) - return [h.payload["data"] for h in hits] + + async def _prefetch(hit: models.ScoredPoint) -> None: + data = hit.payload["data"] + rating_key = str(data.get("plex", {}).get("rating_key")) + if rating_key: + _cache_set(_payload_cache, rating_key, data) + thumb = data.get("plex", {}).get("thumb") + if thumb: + _cache_set(_poster_cache, rating_key, thumb) + art = data.get("plex", {}).get("art") + if art: + _cache_set(_background_cache, rating_key, art) + + prefetch_task = asyncio.gather(*[_prefetch(h) for h in hits[:limit]]) + + def _rerank(hits: list[models.ScoredPoint]) -> list[models.ScoredPoint]: + return hits + + reranked = await asyncio.to_thread(_rerank, hits) + await prefetch_task + return [h.payload["data"] for h in reranked[:limit]] @server.tool("recommend-media") @@ -200,12 +257,14 @@ async def media_poster( ], ) -> str: """Return the poster image URL for the given media identifier.""" - records = await _find_records(identifier, limit=1) - if not records: - raise ValueError("Media item not found") - thumb = records[0].payload["data"].get("plex", {}).get("thumb") + cached = _cache_get(_poster_cache, identifier) + if cached: + return cached + data = await _get_media_data(identifier) + thumb = data.get("plex", {}).get("thumb") if not thumb: raise ValueError("Poster not available") + _cache_set(_poster_cache, str(data.get("plex", {}).get("rating_key")), thumb) return thumb @@ -220,12 +279,14 @@ async def media_background( ], ) -> str: """Return the background art URL for the given media identifier.""" - records = await _find_records(identifier, limit=1) - if not records: - raise ValueError("Media item not found") - art = records[0].payload["data"].get("plex", {}).get("art") + cached = _cache_get(_background_cache, identifier) + if cached: + return cached + data = await _get_media_data(identifier) + art = data.get("plex", {}).get("art") if not art: raise ValueError("Background not available") + _cache_set(_background_cache, str(data.get("plex", {}).get("rating_key")), art) return art diff --git a/tests/test_server.py b/tests/test_server.py index 4dd64ef..6594af0 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -4,6 +4,7 @@ import importlib import types import json +import time import pytest from mcp_plex import loader @@ -20,6 +21,7 @@ def list_supported_models(): def embed(self, texts): for _ in texts: + time.sleep(0.1) yield [0.1, 0.2, 0.3] @@ -47,6 +49,7 @@ def passage_embed(self, texts): yield DummySparseVector([i], [1.0]) def query_embed(self, text): + time.sleep(0.1) return DummySparseVector([0], [1.0]) @@ -139,11 +142,38 @@ def test_server_tools(tmp_path, monkeypatch): res = asyncio.run(server.get_media.fn(identifier="The Gentlemen")) assert res and res[0]["plex"]["rating_key"] == movie_id + start = time.perf_counter() res = asyncio.run( server.search_media.fn(query="Matthew McConaughey crime movie", limit=1) ) + elapsed = time.perf_counter() - start + assert elapsed < 0.2 assert res and res[0]["plex"]["title"] == "The Gentlemen" + # Prefetched payloads should allow resource access without hitting the client + orig_retrieve, orig_scroll = server._client.retrieve, server._client.scroll + + async def fail(*args, **kwargs): # pragma: no cover + raise AssertionError("client called") + + server._client.retrieve = fail + server._client.scroll = fail + try: + poster = asyncio.run(server.media_poster.fn(identifier=movie_id)) + assert isinstance(poster, str) and "thumb" in poster + + art = asyncio.run(server.media_background.fn(identifier=movie_id)) + assert isinstance(art, str) and "art" in art + + item = json.loads(asyncio.run(server.media_item.fn(identifier=movie_id))) + assert item["plex"]["rating_key"] == movie_id + + ids = json.loads(asyncio.run(server.media_ids.fn(identifier=movie_id))) + assert ids["imdb"] == "tt8367814" + finally: + server._client.retrieve = orig_retrieve + server._client.scroll = orig_scroll + res = asyncio.run(server.recommend_media.fn(identifier=movie_id, limit=1)) assert res and res[0]["plex"]["rating_key"] == "61960" @@ -154,18 +184,6 @@ def test_server_tools(tmp_path, monkeypatch): # Exercise search path with an ID that doesn't exist asyncio.run(server._find_records("12345", limit=1)) - poster = asyncio.run(server.media_poster.fn(identifier=movie_id)) - assert isinstance(poster, str) and "thumb" in poster - - art = asyncio.run(server.media_background.fn(identifier=movie_id)) - assert isinstance(art, str) and "art" in art - - item = json.loads(asyncio.run(server.media_item.fn(identifier=movie_id))) - assert item["plex"]["rating_key"] == movie_id - - ids = json.loads(asyncio.run(server.media_ids.fn(identifier=movie_id))) - assert ids["imdb"] == "tt8367814" - with pytest.raises(ValueError): asyncio.run(server.media_item.fn(identifier="0")) with pytest.raises(ValueError):