Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 74 additions & 13 deletions mcp_plex/server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""FastMCP server exposing Plex metadata tools."""
from __future__ import annotations

import asyncio
import os
import json
from collections import OrderedDict
from typing import Any, Annotated

from fastmcp.server import FastMCP
Expand All @@ -23,6 +25,27 @@
server = FastMCP()


_CACHE_SIZE = 128
_payload_cache: OrderedDict[str, dict[str, Any]] = OrderedDict()
_poster_cache: OrderedDict[str, str] = OrderedDict()
_background_cache: OrderedDict[str, str] = OrderedDict()


def _cache_set(cache: OrderedDict, key: str, value: Any) -> None:
if key in cache:
cache.move_to_end(key)
cache[key] = value
while len(cache) > _CACHE_SIZE:
cache.popitem(last=False)


def _cache_get(cache: OrderedDict, key: str) -> Any | None:
if key in cache:
cache.move_to_end(key)
return cache[key]
return None


async def _find_records(identifier: str, limit: int = 5) -> list[models.Record]:
"""Locate records matching an identifier or title."""
# First, try direct ID lookup
Expand Down Expand Up @@ -60,10 +83,23 @@ async def _find_records(identifier: str, limit: int = 5) -> list[models.Record]:

async def _get_media_data(identifier: str) -> dict[str, Any]:
"""Return the first matching media record's payload."""
cached = _cache_get(_payload_cache, identifier)
if cached is not None:
return cached
records = await _find_records(identifier, limit=1)
if not records:
raise ValueError("Media item not found")
return records[0].payload["data"]
data = records[0].payload["data"]
rating_key = str(data.get("plex", {}).get("rating_key"))
if rating_key:
_cache_set(_payload_cache, rating_key, data)
thumb = data.get("plex", {}).get("thumb")
if thumb:
_cache_set(_poster_cache, rating_key, thumb)
art = data.get("plex", {}).get("art")
if art:
_cache_set(_background_cache, rating_key, art)
return data


@server.tool("get-media")
Expand Down Expand Up @@ -101,8 +137,9 @@ async def search_media(
] = 5,
) -> list[dict[str, Any]]:
"""Hybrid similarity search across media items using dense and sparse vectors."""
dense_vec = list(_dense_model.embed([query]))[0]
sparse_vec = _sparse_model.query_embed(query)
dense_task = asyncio.to_thread(lambda: list(_dense_model.embed([query]))[0])
sparse_task = asyncio.to_thread(lambda: _sparse_model.query_embed(query))
dense_vec, sparse_vec = await asyncio.gather(dense_task, sparse_task)
named_dense = models.NamedVector(name="dense", vector=dense_vec)
sv = models.SparseVector(
indices=sparse_vec.indices.tolist(), values=sparse_vec.values.tolist()
Expand All @@ -112,10 +149,30 @@ async def search_media(
collection_name="media-items",
query_vector=named_dense,
query_sparse_vector=named_sparse,
limit=limit,
limit=limit * 3,
with_payload=True,
)
return [h.payload["data"] for h in hits]

async def _prefetch(hit: models.ScoredPoint) -> None:
data = hit.payload["data"]
rating_key = str(data.get("plex", {}).get("rating_key"))
if rating_key:
_cache_set(_payload_cache, rating_key, data)
thumb = data.get("plex", {}).get("thumb")
if thumb:
_cache_set(_poster_cache, rating_key, thumb)
art = data.get("plex", {}).get("art")
if art:
_cache_set(_background_cache, rating_key, art)

prefetch_task = asyncio.gather(*[_prefetch(h) for h in hits[:limit]])

def _rerank(hits: list[models.ScoredPoint]) -> list[models.ScoredPoint]:
return hits

reranked = await asyncio.to_thread(_rerank, hits)
await prefetch_task
return [h.payload["data"] for h in reranked[:limit]]


@server.tool("recommend-media")
Expand Down Expand Up @@ -200,12 +257,14 @@ async def media_poster(
],
) -> str:
"""Return the poster image URL for the given media identifier."""
records = await _find_records(identifier, limit=1)
if not records:
raise ValueError("Media item not found")
thumb = records[0].payload["data"].get("plex", {}).get("thumb")
cached = _cache_get(_poster_cache, identifier)
if cached:
return cached
data = await _get_media_data(identifier)
thumb = data.get("plex", {}).get("thumb")
if not thumb:
raise ValueError("Poster not available")
_cache_set(_poster_cache, str(data.get("plex", {}).get("rating_key")), thumb)
return thumb


Expand All @@ -220,12 +279,14 @@ async def media_background(
],
) -> str:
"""Return the background art URL for the given media identifier."""
records = await _find_records(identifier, limit=1)
if not records:
raise ValueError("Media item not found")
art = records[0].payload["data"].get("plex", {}).get("art")
cached = _cache_get(_background_cache, identifier)
if cached:
return cached
data = await _get_media_data(identifier)
art = data.get("plex", {}).get("art")
if not art:
raise ValueError("Background not available")
_cache_set(_background_cache, str(data.get("plex", {}).get("rating_key")), art)
return art


Expand Down
42 changes: 30 additions & 12 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import importlib
import types
import json
import time
import pytest

from mcp_plex import loader
Expand All @@ -20,6 +21,7 @@ def list_supported_models():

def embed(self, texts):
for _ in texts:
time.sleep(0.1)
yield [0.1, 0.2, 0.3]


Expand Down Expand Up @@ -47,6 +49,7 @@ def passage_embed(self, texts):
yield DummySparseVector([i], [1.0])

def query_embed(self, text):
time.sleep(0.1)
return DummySparseVector([0], [1.0])


Expand Down Expand Up @@ -139,11 +142,38 @@ def test_server_tools(tmp_path, monkeypatch):
res = asyncio.run(server.get_media.fn(identifier="The Gentlemen"))
assert res and res[0]["plex"]["rating_key"] == movie_id

start = time.perf_counter()
res = asyncio.run(
server.search_media.fn(query="Matthew McConaughey crime movie", limit=1)
)
elapsed = time.perf_counter() - start
assert elapsed < 0.2
assert res and res[0]["plex"]["title"] == "The Gentlemen"

# Prefetched payloads should allow resource access without hitting the client
orig_retrieve, orig_scroll = server._client.retrieve, server._client.scroll

async def fail(*args, **kwargs): # pragma: no cover
raise AssertionError("client called")

server._client.retrieve = fail
server._client.scroll = fail
try:
poster = asyncio.run(server.media_poster.fn(identifier=movie_id))
assert isinstance(poster, str) and "thumb" in poster

art = asyncio.run(server.media_background.fn(identifier=movie_id))
assert isinstance(art, str) and "art" in art

item = json.loads(asyncio.run(server.media_item.fn(identifier=movie_id)))
assert item["plex"]["rating_key"] == movie_id

ids = json.loads(asyncio.run(server.media_ids.fn(identifier=movie_id)))
assert ids["imdb"] == "tt8367814"
finally:
server._client.retrieve = orig_retrieve
server._client.scroll = orig_scroll

res = asyncio.run(server.recommend_media.fn(identifier=movie_id, limit=1))
assert res and res[0]["plex"]["rating_key"] == "61960"

Expand All @@ -154,18 +184,6 @@ def test_server_tools(tmp_path, monkeypatch):
# Exercise search path with an ID that doesn't exist
asyncio.run(server._find_records("12345", limit=1))

poster = asyncio.run(server.media_poster.fn(identifier=movie_id))
assert isinstance(poster, str) and "thumb" in poster

art = asyncio.run(server.media_background.fn(identifier=movie_id))
assert isinstance(art, str) and "art" in art

item = json.loads(asyncio.run(server.media_item.fn(identifier=movie_id)))
assert item["plex"]["rating_key"] == movie_id

ids = json.loads(asyncio.run(server.media_ids.fn(identifier=movie_id)))
assert ids["imdb"] == "tt8367814"

with pytest.raises(ValueError):
asyncio.run(server.media_item.fn(identifier="0"))
with pytest.raises(ValueError):
Expand Down