From 976ddd675403feaa7ab390c99f59889258c4bfe3 Mon Sep 17 00:00:00 2001 From: Teagan Glenn Date: Mon, 6 Oct 2025 22:31:49 -0600 Subject: [PATCH] feat(common): tighten cache payload typing --- docker/pyproject.deps.toml | 2 +- mcp_plex/common/AGENTS.md | 1 + mcp_plex/common/__init__.py | 3 +- mcp_plex/common/cache.py | 22 +++++++--- mcp_plex/common/types.py | 7 +++- mcp_plex/common/validation.py | 8 +++- mcp_plex/loader/__init__.py | 3 +- mcp_plex/loader/imdb_cache.py | 7 +--- mcp_plex/server/__init__.py | 72 ++++++++++++++++++++++++--------- pyproject.toml | 2 +- tests/test_common_validation.py | 32 ++++++++++----- uv.lock | 2 +- 12 files changed, 111 insertions(+), 50 deletions(-) diff --git a/docker/pyproject.deps.toml b/docker/pyproject.deps.toml index 5829ba9..8a5de07 100644 --- a/docker/pyproject.deps.toml +++ b/docker/pyproject.deps.toml @@ -1,6 +1,6 @@ [project] name = "mcp-plex" -version = "1.0.18" +version = "1.0.19" requires-python = ">=3.11,<3.13" dependencies = [ "fastmcp>=2.11.2", diff --git a/mcp_plex/common/AGENTS.md b/mcp_plex/common/AGENTS.md index b2e9df3..408b0c3 100644 --- a/mcp_plex/common/AGENTS.md +++ b/mcp_plex/common/AGENTS.md @@ -4,4 +4,5 @@ - `mcp_plex.common` provides shared cache helpers, data models, and utility types consumed by both the loader and the server packages. - Keep shared logic decoupled from CLI wiring so it can be imported safely by tests and other packages. - Update this module when adding reusable functionality to avoid duplicating code between the loader and server implementations. +- Use the `JSONValue` and related aliases in `types.py` when exchanging cached payloads or structured JSON-like data. Media caches and downstream consumers expect payload dictionaries to resolve to `dict[str, JSONValue]` without falling back to ``Any``. diff --git a/mcp_plex/common/__init__.py b/mcp_plex/common/__init__.py index db87bdd..afe2a85 100644 --- a/mcp_plex/common/__init__.py +++ b/mcp_plex/common/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations from .cache import MediaCache +from .types import JSONValue from .validation import require_positive -__all__ = ["MediaCache", "require_positive"] +__all__ = ["MediaCache", "JSONValue", "require_positive"] diff --git a/mcp_plex/common/cache.py b/mcp_plex/common/cache.py index e1f6e76..e0aa9b5 100644 --- a/mcp_plex/common/cache.py +++ b/mcp_plex/common/cache.py @@ -2,7 +2,13 @@ from __future__ import annotations from collections import OrderedDict -from typing import Any +from typing import TypeVar + +from .types import JSONValue + + +CachedPayload = dict[str, JSONValue] +_CacheValueT = TypeVar("_CacheValueT") class MediaCache: @@ -10,27 +16,31 @@ class MediaCache: def __init__(self, size: int = 128) -> None: self.size = size - self._payload: OrderedDict[str, dict[str, Any]] = OrderedDict() + self._payload: OrderedDict[str, CachedPayload] = OrderedDict() self._poster: OrderedDict[str, str] = OrderedDict() self._background: OrderedDict[str, str] = OrderedDict() - def _set(self, cache: OrderedDict, key: str, value: Any) -> None: + def _set( + self, cache: OrderedDict[str, _CacheValueT], key: str, value: _CacheValueT + ) -> None: if key in cache: cache.move_to_end(key) cache[key] = value while len(cache) > self.size: cache.popitem(last=False) - def _get(self, cache: OrderedDict, key: str) -> Any | None: + def _get( + self, cache: OrderedDict[str, _CacheValueT], key: str + ) -> _CacheValueT | None: if key in cache: cache.move_to_end(key) return cache[key] return None - def get_payload(self, key: str) -> dict[str, Any] | None: + def get_payload(self, key: str) -> CachedPayload | None: return self._get(self._payload, key) - def set_payload(self, key: str, value: dict[str, Any]) -> None: + def set_payload(self, key: str, value: CachedPayload) -> None: self._set(self._payload, key, value) def get_poster(self, key: str) -> str | None: diff --git a/mcp_plex/common/types.py b/mcp_plex/common/types.py index a344a93..baf0132 100644 --- a/mcp_plex/common/types.py +++ b/mcp_plex/common/types.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from datetime import datetime -from typing import List, Literal, Optional +from typing import List, Literal, Mapping, MutableMapping, Optional, Sequence, TypeAlias from pydantic import BaseModel, Field @@ -176,3 +176,8 @@ class ExternalIDs: "AggregatedItem", "ExternalIDs", ] +JSONScalar: TypeAlias = str | int | float | bool | None +JSONValue: TypeAlias = JSONScalar | Sequence["JSONValue"] | Mapping[str, "JSONValue"] +JSONMapping: TypeAlias = Mapping[str, JSONValue] +MutableJSONMapping: TypeAlias = MutableMapping[str, JSONValue] + diff --git a/mcp_plex/common/validation.py b/mcp_plex/common/validation.py index 3d17921..eb562b9 100644 --- a/mcp_plex/common/validation.py +++ b/mcp_plex/common/validation.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any +from typing import SupportsInt def require_positive(value: int, *, name: str) -> int: @@ -15,9 +15,13 @@ def require_positive(value: int, *, name: str) -> int: return value -def coerce_plex_tag_id(raw_id: Any) -> int: +def coerce_plex_tag_id(raw_id: int | str | SupportsInt | None) -> int: """Best-effort conversion of Plex media tag identifiers to integers.""" + if raw_id is None: + return 0 + if isinstance(raw_id, bool): + return int(raw_id) if isinstance(raw_id, int): return raw_id if isinstance(raw_id, str): diff --git a/mcp_plex/loader/__init__.py b/mcp_plex/loader/__init__.py index d44f8d0..2e862dc 100644 --- a/mcp_plex/loader/__init__.py +++ b/mcp_plex/loader/__init__.py @@ -19,7 +19,7 @@ from plexapi.base import PlexPartialObject as _PlexPartialObject from plexapi.server import PlexServer -from .imdb_cache import IMDbCache, JSONValue +from .imdb_cache import IMDbCache from .pipeline.channels import ( IMDbRetryQueue, INGEST_DONE, @@ -33,6 +33,7 @@ from ..common.types import ( AggregatedItem, IMDbTitle, + JSONValue, PlexGuid, PlexItem, PlexPerson, diff --git a/mcp_plex/loader/imdb_cache.py b/mcp_plex/loader/imdb_cache.py index f06b6a4..68dfb08 100644 --- a/mcp_plex/loader/imdb_cache.py +++ b/mcp_plex/loader/imdb_cache.py @@ -7,12 +7,7 @@ from pydantic import ValidationError -from ..common.types import IMDbTitle - -JSONScalar: TypeAlias = str | int | float | bool | None -JSONValue: TypeAlias = ( - JSONScalar | list["JSONValue"] | dict[str, "JSONValue"] -) +from ..common.types import IMDbTitle, JSONValue CachedIMDbPayload: TypeAlias = IMDbTitle | JSONValue diff --git a/mcp_plex/server/__init__.py b/mcp_plex/server/__init__.py index 65e635a..ece8115 100644 --- a/mcp_plex/server/__init__.py +++ b/mcp_plex/server/__init__.py @@ -9,7 +9,7 @@ import logging import os import uuid -from typing import Annotated, Any, Callable, Sequence +from typing import Annotated, Any, Callable, Mapping, Sequence, cast from fastapi import FastAPI from fastapi.openapi.docs import get_swagger_ui_html @@ -28,6 +28,7 @@ from rapidfuzz import fuzz, process from ..common.cache import MediaCache +from ..common.types import JSONValue from .config import Settings @@ -204,10 +205,15 @@ async def _find_records(identifier: str, limit: int = 5) -> list[models.Record]: return points -def _flatten_payload(payload: dict[str, Any]) -> dict[str, Any]: +def _flatten_payload(payload: Mapping[str, JSONValue] | None) -> dict[str, JSONValue]: """Merge top-level payload fields with the nested data block.""" - data = dict(payload.get("data", {})) + data: dict[str, JSONValue] = {} + if not payload: + return data + base = payload.get("data") + if isinstance(base, dict): + data.update(base) for key, value in payload.items(): if key == "data": continue @@ -215,7 +221,7 @@ def _flatten_payload(payload: dict[str, Any]) -> dict[str, Any]: return data -async def _get_media_data(identifier: str) -> dict[str, Any]: +async def _get_media_data(identifier: str) -> dict[str, JSONValue]: """Return the first matching media record's payload.""" cached = server.cache.get_payload(identifier) if cached is not None: @@ -223,10 +229,12 @@ async def _get_media_data(identifier: str) -> dict[str, Any]: records = await _find_records(identifier, limit=1) if not records: raise ValueError("Media item not found") - payload = _flatten_payload(records[0].payload) + payload = _flatten_payload( + cast(Mapping[str, JSONValue] | None, records[0].payload) + ) data = payload - def _normalize_identifier(value: Any) -> str | None: + def _normalize_identifier(value: JSONValue) -> str | None: if value is None: return None if isinstance(value, str): @@ -243,7 +251,10 @@ def _normalize_identifier(value: Any) -> str | None: if lookup_key: cache_keys.add(lookup_key) - plex_data = data.get("plex", {}) or {} + plex_value = data.get("plex") + plex_data: dict[str, JSONValue] = ( + plex_value if isinstance(plex_value, dict) else {} + ) rating_key = _normalize_identifier(plex_data.get("rating_key")) if rating_key: cache_keys.add(rating_key) @@ -252,9 +263,9 @@ def _normalize_identifier(value: Any) -> str | None: cache_keys.add(guid) for source_key in ("imdb", "tmdb", "tvdb"): - source_data = data.get(source_key) - if isinstance(source_data, dict): - source_id = _normalize_identifier(source_data.get("id")) + source_value = data.get(source_key) + if isinstance(source_value, dict): + source_id = _normalize_identifier(source_value.get("id")) if source_id: cache_keys.add(source_id) @@ -529,7 +540,10 @@ async def get_media( ) -> list[dict[str, Any]]: """Retrieve media items by rating key, IMDb/TMDb ID or title.""" records = await _find_records(identifier, limit=10) - return [_flatten_payload(r.payload) for r in records] + return [ + _flatten_payload(cast(Mapping[str, JSONValue] | None, r.payload)) + for r in records + ] @server.tool("search-media") @@ -578,7 +592,7 @@ async def search_media( hits = res.points async def _prefetch(hit: models.ScoredPoint) -> None: - data = _flatten_payload(hit.payload) + data = _flatten_payload(cast(Mapping[str, JSONValue] | None, hit.payload)) rating_key = str(data.get("plex", {}).get("rating_key")) if rating_key: server.cache.set_payload(rating_key, data) @@ -596,7 +610,9 @@ def _rerank(hits: list[models.ScoredPoint]) -> list[models.ScoredPoint]: return hits docs: list[str] = [] for h in hits: - data = _flatten_payload(h.payload) + data = _flatten_payload( + cast(Mapping[str, JSONValue] | None, h.payload) + ) parts = [ data.get("title"), data.get("summary"), @@ -648,7 +664,10 @@ def _join_people(values: Any) -> str: reranked = await asyncio.to_thread(_rerank, hits) await prefetch_task - return [_flatten_payload(h.payload) for h in reranked[:limit]] + return [ + _flatten_payload(cast(Mapping[str, JSONValue] | None, h.payload)) + for h in reranked[:limit] + ] @server.tool("query-media") @@ -940,7 +959,10 @@ def _listify(value: Sequence[str] | str | None) -> list[str]: limit=limit, with_payload=True, ) - return [_flatten_payload(p.payload) for p in res.points] + return [ + _flatten_payload(cast(Mapping[str, JSONValue] | None, p.payload)) + for p in res.points + ] @server.tool("recommend-media") @@ -979,7 +1001,10 @@ async def recommend_media( with_payload=True, using="dense", ) - return [_flatten_payload(r.payload) for r in response.points] + return [ + _flatten_payload(cast(Mapping[str, JSONValue] | None, r.payload)) + for r in response.points + ] @server.tool("new-movies") @@ -1012,7 +1037,10 @@ async def new_movies( limit=limit, with_payload=True, ) - return [_flatten_payload(p.payload) for p in res.points] + return [ + _flatten_payload(cast(Mapping[str, JSONValue] | None, p.payload)) + for p in res.points + ] @server.tool("new-shows") @@ -1045,7 +1073,10 @@ async def new_shows( limit=limit, with_payload=True, ) - return [_flatten_payload(p.payload) for p in res.points] + return [ + _flatten_payload(cast(Mapping[str, JSONValue] | None, p.payload)) + for p in res.points + ] @server.tool("actor-movies") @@ -1098,7 +1129,10 @@ async def actor_movies( limit=limit, with_payload=True, ) - return [_flatten_payload(p.payload) for p in res.points] + return [ + _flatten_payload(cast(Mapping[str, JSONValue] | None, p.payload)) + for p in res.points + ] @server.resource("resource://media-item/{identifier}") diff --git a/pyproject.toml b/pyproject.toml index 0110439..6ab0043 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mcp-plex" -version = "1.0.18" +version = "1.0.19" description = "Plex-Oriented Model Context Protocol Server" requires-python = ">=3.11,<3.13" diff --git a/tests/test_common_validation.py b/tests/test_common_validation.py index ed16068..985ffc8 100644 --- a/tests/test_common_validation.py +++ b/tests/test_common_validation.py @@ -21,14 +21,24 @@ def test_require_positive_enforces_int_type(bad_type: object) -> None: require_positive(bad_type, name="value") # type: ignore[arg-type] -def test_coerce_plex_tag_id_accepts_ints() -> None: - assert coerce_plex_tag_id(7) == 7 - - -def test_coerce_plex_tag_id_coerces_strings() -> None: - assert coerce_plex_tag_id(" 42 ") == 42 - - -def test_coerce_plex_tag_id_handles_invalid_values() -> None: - assert coerce_plex_tag_id(None) == 0 - assert coerce_plex_tag_id("not-a-number") == 0 +class _SupportsInt: + def __int__(self) -> int: + return 128 + + +@pytest.mark.parametrize( + "raw, expected", + [ + (7, 7), + (True, 1), + (" 42 ", 42), + (_SupportsInt(), 128), + ], +) +def test_coerce_plex_tag_id_normalizes_values(raw, expected) -> None: + assert coerce_plex_tag_id(raw) == expected + + +@pytest.mark.parametrize("raw", [None, "", "not-a-number"]) +def test_coerce_plex_tag_id_handles_invalid_values(raw) -> None: + assert coerce_plex_tag_id(raw) == 0 diff --git a/uv.lock b/uv.lock index d691f41..8e0e885 100644 --- a/uv.lock +++ b/uv.lock @@ -730,7 +730,7 @@ wheels = [ [[package]] name = "mcp-plex" -version = "1.0.18" +version = "1.0.19" source = { editable = "." } dependencies = [ { name = "fastapi" },