From be1d25f1050a9290377c23a815c5d302bb11025c Mon Sep 17 00:00:00 2001 From: Teagan Glenn Date: Mon, 6 Oct 2025 22:15:11 -0600 Subject: [PATCH 1/2] refactor: tighten loader typing --- mcp_plex/loader/AGENTS.md | 4 + mcp_plex/loader/__init__.py | 95 ++++++++++------ mcp_plex/loader/imdb_cache.py | 37 +++++- mcp_plex/loader/pipeline/__init__.py | 36 +----- mcp_plex/loader/pipeline/channels.py | 5 +- mcp_plex/loader/pipeline/enrichment.py | 136 +++++++++++++++-------- mcp_plex/loader/pipeline/ingestion.py | 8 +- mcp_plex/loader/pipeline/orchestrator.py | 25 ++++- mcp_plex/loader/pipeline/persistence.py | 21 +++- tests/test_enrichment_stage.py | 39 +++++-- 10 files changed, 261 insertions(+), 145 deletions(-) diff --git a/mcp_plex/loader/AGENTS.md b/mcp_plex/loader/AGENTS.md index ff696f3..25f8664 100644 --- a/mcp_plex/loader/AGENTS.md +++ b/mcp_plex/loader/AGENTS.md @@ -27,3 +27,7 @@ - `LoaderOrchestrator` must be initialised with the three stage instances, the ingest queue, the persistence queue, and the number of persistence workers (the CLI's `max_concurrent_upserts`). - Convert `AggregatedItem` batches into Qdrant `PointStruct` objects with `build_point` before handing them to the persistence stage's `enqueue_points` helper. - Prefer explicit keyword arguments when threading CLI options into stage constructors so the mapping is obvious to future readers. + +## Typing Guidelines +- Avoid introducing new ``Any`` or bare ``object`` annotations in loader modules. Use ``TypedDict`` definitions, ``Protocol`` classes, or precise unions instead. +- When wider typing is unavoidable, leave a brief comment explaining why the loosening is necessary so future contributors can revisit it. diff --git a/mcp_plex/loader/__init__.py b/mcp_plex/loader/__init__.py index 1f3a1bb..d44f8d0 100644 --- a/mcp_plex/loader/__init__.py +++ b/mcp_plex/loader/__init__.py @@ -9,7 +9,7 @@ import warnings from dataclasses import dataclass, field from pathlib import Path -from typing import Any, List, Optional, Sequence, TypeVar +from typing import TYPE_CHECKING, Sequence, TypedDict, TypeVar import click import httpx @@ -19,7 +19,7 @@ from plexapi.base import PlexPartialObject as _PlexPartialObject from plexapi.server import PlexServer -from .imdb_cache import IMDbCache +from .imdb_cache import IMDbCache, JSONValue from .pipeline.channels import ( IMDbRetryQueue, INGEST_DONE, @@ -50,6 +50,10 @@ category=RuntimeWarning, ) +if TYPE_CHECKING: # pragma: no cover - import for typing only + from .pipeline.enrichment import _RequestThrottler + + T = TypeVar("T") IMDB_BATCH_LIMIT: int = 5 @@ -70,9 +74,9 @@ class IMDbRuntimeConfig: retry_queue: IMDbRetryQueue requests_per_window: int | None window_seconds: float - _throttle: Any = field(default=None, init=False, repr=False) + _throttle: _RequestThrottler | None = field(default=None, init=False, repr=False) - def get_throttle(self) -> Any: + def get_throttle(self) -> _RequestThrottler | None: """Return the shared rate limiter, creating it on first use.""" if self.requests_per_window is None: @@ -132,7 +136,7 @@ async def _fetch_imdb( client: httpx.AsyncClient, imdb_id: str, config: IMDbRuntimeConfig, -) -> Optional[IMDbTitle]: +) -> IMDbTitle | None: """Fetch metadata for an IMDb ID with caching, retry, and throttling.""" from .pipeline import enrichment as enrichment_mod @@ -390,10 +394,35 @@ def _build_point_text(item: AggregatedItem) -> str: return "\n".join(p for p in parts if p) -def _build_point_payload(item: AggregatedItem) -> dict[str, object]: +class _BaseQdrantPayload(TypedDict): + data: dict[str, JSONValue] + title: str + type: str + + +class QdrantPayload(_BaseQdrantPayload, total=False): + show_title: str + season_title: str + season_number: int + episode_number: int + actors: list[str] + directors: list[str] + writers: list[str] + genres: list[str] + collections: list[str] + summary: str + overview: str + plot: str + tagline: str + reviews: list[str] + year: int + added_at: int + + +def _build_point_payload(item: AggregatedItem) -> QdrantPayload: """Construct the Qdrant payload for ``item``.""" - payload: dict[str, object] = { + payload: QdrantPayload = { "data": item.model_dump(mode="json"), "title": item.plex.title, "type": item.plex.type, @@ -470,10 +499,10 @@ def build_point( ) -def _load_from_sample(sample_dir: Path) -> List[AggregatedItem]: +def _load_from_sample(sample_dir: Path) -> list[AggregatedItem]: """Load items from local sample JSON files.""" - results: List[AggregatedItem] = [] + results: list[AggregatedItem] = [] movie_dir = sample_dir / "movie" episode_dir = sample_dir / "episode" @@ -713,13 +742,13 @@ def _record_upsert(worker_id: int, batch_size: int, queue_size: int) -> None: async def run( - plex_url: Optional[str], - plex_token: Optional[str], - tmdb_api_key: Optional[str], - sample_dir: Optional[Path], - qdrant_url: Optional[str], - qdrant_api_key: Optional[str], - qdrant_host: Optional[str] = None, + plex_url: str | None, + plex_token: str | None, + tmdb_api_key: str | None, + sample_dir: Path | None, + qdrant_url: str | None, + qdrant_api_key: str | None, + qdrant_host: str | None = None, qdrant_port: int = 6333, qdrant_grpc_port: int = 6334, qdrant_https: bool = False, @@ -799,7 +828,7 @@ async def run( dense_distance=dense_distance, ) - items: List[AggregatedItem] + items: list[AggregatedItem] if sample_dir is not None: logger.info("Loading sample data from %s", sample_dir) sample_items = _load_from_sample(sample_dir) @@ -1087,13 +1116,13 @@ async def run( help="Path to persistent IMDb retry queue", ) def main( - plex_url: Optional[str], - plex_token: Optional[str], - tmdb_api_key: Optional[str], - sample_dir: Optional[Path], - qdrant_url: Optional[str], - qdrant_api_key: Optional[str], - qdrant_host: Optional[str], + plex_url: str | None, + plex_token: str | None, + tmdb_api_key: str | None, + sample_dir: Path | None, + qdrant_url: str | None, + qdrant_api_key: str | None, + qdrant_host: str | None, qdrant_port: int, qdrant_grpc_port: int, qdrant_https: bool, @@ -1109,7 +1138,7 @@ def main( imdb_cache: Path, imdb_max_retries: int, imdb_backoff: float, - imdb_requests_per_window: Optional[int], + imdb_requests_per_window: int | None, imdb_window_seconds: float, imdb_queue: Path, log_level: str, @@ -1150,13 +1179,13 @@ def main( async def load_media( - plex_url: Optional[str], - plex_token: Optional[str], - tmdb_api_key: Optional[str], - sample_dir: Optional[Path], - qdrant_url: Optional[str], - qdrant_api_key: Optional[str], - qdrant_host: Optional[str], + plex_url: str | None, + plex_token: str | None, + tmdb_api_key: str | None, + sample_dir: Path | None, + qdrant_url: str | None, + qdrant_api_key: str | None, + qdrant_host: str | None, qdrant_port: int, qdrant_grpc_port: int, qdrant_https: bool, @@ -1168,7 +1197,7 @@ async def load_media( imdb_cache: Path, imdb_max_retries: int, imdb_backoff: float, - imdb_requests_per_window: Optional[int], + imdb_requests_per_window: int | None, imdb_window_seconds: float, imdb_queue: Path, upsert_buffer_size: int, diff --git a/mcp_plex/loader/imdb_cache.py b/mcp_plex/loader/imdb_cache.py index ece06ba..ad833b4 100644 --- a/mcp_plex/loader/imdb_cache.py +++ b/mcp_plex/loader/imdb_cache.py @@ -3,7 +3,17 @@ import json import logging from pathlib import Path -from typing import Any +from typing import TypeAlias, cast + +from ..common.types import IMDbTitle + +JSONScalar: TypeAlias = str | int | float | bool | None +JSONValue: TypeAlias = ( + JSONScalar | list["JSONValue"] | dict[str, "JSONValue"] +) + + +CachedIMDbPayload: TypeAlias = IMDbTitle | JSONValue class IMDbCache: @@ -13,7 +23,7 @@ class IMDbCache: def __init__(self, path: Path) -> None: self.path = path - self._data: dict[str, Any] = {} + self._data: dict[str, CachedIMDbPayload] = {} if path.exists(): try: raw_contents = path.read_text(encoding="utf-8") @@ -25,22 +35,37 @@ def __init__(self, path: Path) -> None: ) else: try: - self._data = json.loads(raw_contents) + loaded = json.loads(raw_contents) except (json.JSONDecodeError, UnicodeError) as exc: self._logger.warning( "Failed to decode IMDb cache JSON from %s; starting with empty cache.", path, exc_info=exc, ) + else: + if isinstance(loaded, dict): + self._data = { + str(key): cast(CachedIMDbPayload, value) + for key, value in loaded.items() + } + else: + self._logger.warning( + "IMDb cache at %s did not contain an object; ignoring its contents.", + path, + ) - def get(self, imdb_id: str) -> dict[str, Any] | None: + def get(self, imdb_id: str) -> CachedIMDbPayload | None: """Return cached data for ``imdb_id`` if present.""" return self._data.get(imdb_id) - def set(self, imdb_id: str, data: dict[str, Any]) -> None: + def set(self, imdb_id: str, data: CachedIMDbPayload) -> None: """Store ``data`` under ``imdb_id`` and persist to disk.""" self._data[imdb_id] = data self.path.parent.mkdir(parents=True, exist_ok=True) - self.path.write_text(json.dumps(self._data)) + serialisable = { + key: value.model_dump() if isinstance(value, IMDbTitle) else value + for key, value in self._data.items() + } + self.path.write_text(json.dumps(serialisable)) diff --git a/mcp_plex/loader/pipeline/__init__.py b/mcp_plex/loader/pipeline/__init__.py index cbf5d4b..51c1a6b 100644 --- a/mcp_plex/loader/pipeline/__init__.py +++ b/mcp_plex/loader/pipeline/__init__.py @@ -2,9 +2,6 @@ from __future__ import annotations -from importlib import import_module -from typing import TYPE_CHECKING, Any - from .channels import ( EpisodeBatch, IMDbRetryQueue, @@ -19,11 +16,10 @@ ) from ...common.validation import require_positive -if TYPE_CHECKING: - from .enrichment import EnrichmentStage - from .ingestion import IngestionStage - from .orchestrator import LoaderOrchestrator - from .persistence import PersistenceStage +from .enrichment import EnrichmentStage +from .ingestion import IngestionStage +from .orchestrator import LoaderOrchestrator +from .persistence import PersistenceStage __all__ = [ "IngestionStage", @@ -42,27 +38,3 @@ "chunk_sequence", "require_positive", ] - -_STAGE_MODULES = { - "IngestionStage": ".ingestion", - "EnrichmentStage": ".enrichment", - "PersistenceStage": ".persistence", - "LoaderOrchestrator": ".orchestrator", -} - - -def __getattr__(name: str) -> Any: - """Lazily import pipeline stage classes on first access.""" - - if name in _STAGE_MODULES: - module = import_module(f"{__name__}{_STAGE_MODULES[name]}") - value = getattr(module, name) - globals()[name] = value - return value - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - -def __dir__() -> list[str]: - """Return module attributes for introspection tools.""" - - return sorted(set(globals()) | set(__all__)) diff --git a/mcp_plex/loader/pipeline/channels.py b/mcp_plex/loader/pipeline/channels.py index 91f9069..85326d6 100644 --- a/mcp_plex/loader/pipeline/channels.py +++ b/mcp_plex/loader/pipeline/channels.py @@ -12,7 +12,6 @@ from dataclasses import dataclass from typing import ( TYPE_CHECKING, - Any, Final, Iterable, Literal, @@ -46,8 +45,8 @@ if TYPE_CHECKING: PersistencePayload: TypeAlias = list[models.PointStruct] -else: # pragma: no cover - runtime fallback for typing-only alias - PersistencePayload: TypeAlias = list[Any] + +PersistencePayload: TypeAlias = list["models.PointStruct"] @dataclass(slots=True) diff --git a/mcp_plex/loader/pipeline/enrichment.py b/mcp_plex/loader/pipeline/enrichment.py index 01043db..17a90ac 100644 --- a/mcp_plex/loader/pipeline/enrichment.py +++ b/mcp_plex/loader/pipeline/enrichment.py @@ -10,12 +10,16 @@ from __future__ import annotations import asyncio +import inspect import logging from collections import deque -from collections.abc import AsyncIterator, Awaitable, Callable, Sequence -from contextlib import asynccontextmanager -import inspect -from typing import Any, Optional +from collections.abc import AsyncIterator, Awaitable, Callable, Mapping, Sequence +from contextlib import ( + AbstractAsyncContextManager, + AbstractContextManager, + asynccontextmanager, +) +from typing import Protocol, cast import httpx from pydantic import ValidationError @@ -54,11 +58,38 @@ LOGGER = logging.getLogger(__name__) +class AsyncHTTPClient(Protocol): + """Minimal async HTTP client interface used by the enrichment stage.""" + + async def get( + self, + url: str, + *, + params: Mapping[str, str] | None = None, + headers: Mapping[str, str] | None = None, + ) -> httpx.Response: + ... + + async def aclose(self) -> None: + ... + + +HTTPClientResource = ( + AsyncHTTPClient + | AbstractAsyncContextManager[AsyncHTTPClient] + | AbstractContextManager[AsyncHTTPClient] +) + +HTTPClientFactory = Callable[ + [], HTTPClientResource | Awaitable[HTTPClientResource] +] + + def _extract_external_ids(item: PlexPartialObject) -> ExternalIDs: """Extract IMDb and TMDb IDs from a Plex object.""" - imdb_id: Optional[str] = None - tmdb_id: Optional[str] = None + imdb_id: str | None = None + tmdb_id: str | None = None for guid in getattr(item, "guids", []) or []: gid = getattr(guid, "id", "") if gid.startswith("imdb://"): @@ -140,8 +171,8 @@ def _build_plex_item(item: PlexPartialObject) -> PlexItem: async def _fetch_tmdb_movie( - client: httpx.AsyncClient, tmdb_id: str, api_key: str -) -> Optional[TMDBMovie]: + client: AsyncHTTPClient, tmdb_id: str, api_key: str +) -> TMDBMovie | None: url = f"https://api.themoviedb.org/3/movie/{tmdb_id}?append_to_response=reviews" try: resp = await client.get(url, headers={"Authorization": f"Bearer {api_key}"}) @@ -154,8 +185,8 @@ async def _fetch_tmdb_movie( async def _fetch_tmdb_show( - client: httpx.AsyncClient, tmdb_id: str, api_key: str -) -> Optional[TMDBShow]: + client: AsyncHTTPClient, tmdb_id: str, api_key: str +) -> TMDBShow | None: url = f"https://api.themoviedb.org/3/tv/{tmdb_id}?append_to_response=reviews" try: resp = await client.get(url, headers={"Authorization": f"Bearer {api_key}"}) @@ -168,12 +199,12 @@ async def _fetch_tmdb_show( async def _fetch_tmdb_episode( - client: httpx.AsyncClient, + client: AsyncHTTPClient, show_id: int, season_number: int, episode_number: int, api_key: str, -) -> Optional[TMDBEpisode]: +) -> TMDBEpisode | None: """Fetch TMDb data for a TV episode.""" url = ( @@ -195,7 +226,7 @@ async def _fetch_tmdb_episode( async def _fetch_tmdb_episode_chunk( - client: httpx.AsyncClient, + client: AsyncHTTPClient, show_id: int, append_paths: Sequence[str], api_key: str, @@ -236,8 +267,8 @@ async def _fetch_tmdb_episode_chunk( def resolve_tmdb_season_number( - show_tmdb: Optional[TMDBShow], episode: PlexPartialObject -) -> Optional[int]: + show_tmdb: TMDBShow | None, episode: PlexPartialObject +) -> int | None: """Map a Plex episode to the appropriate TMDb season number. This resolves cases where Plex uses year-based season indices that do not @@ -259,7 +290,7 @@ def resolve_tmdb_season_number( return season.season_number # match by season name (e.g. "Season 2018" -> "2018") - title_norm: Optional[str] = None + title_norm: str | None = None if isinstance(parent_title, str): title_norm = parent_title.lower().lstrip("season ").strip() for season in seasons: @@ -268,7 +299,7 @@ def resolve_tmdb_season_number( return season.season_number # match by air date year when Plex uses year-based seasons - year: Optional[int] = None + year: int | None = None if isinstance(parent_year, int): year = parent_year elif isinstance(parent_index, int): @@ -298,7 +329,7 @@ class EnrichmentStage: def __init__( self, *, - http_client_factory: Callable[[], Awaitable[Any] | Any], + http_client_factory: HTTPClientFactory, tmdb_api_key: str, ingest_queue: IngestQueue, persistence_queue: PersistenceQueue, @@ -313,7 +344,7 @@ def __init__( imdb_window_seconds: float = 1.0, logger: logging.Logger | None = None, ) -> None: - self._http_client_factory = http_client_factory + self._http_client_factory: HTTPClientFactory = http_client_factory self._tmdb_api_key = (tmdb_api_key or "").strip() self._ingest_queue = ingest_queue self._persistence_queue = persistence_queue @@ -473,37 +504,45 @@ async def _handle_episode_batch(self, batch: EpisodeBatch) -> None: ) @asynccontextmanager - async def _acquire_http_client(self) -> AsyncIterator[Any]: + async def _acquire_http_client(self) -> AsyncIterator[AsyncHTTPClient]: """Yield an HTTP client from the injected factory.""" resource = self._http_client_factory() if inspect.isawaitable(resource): resource = await resource - if hasattr(resource, "__aenter__") and hasattr(resource, "__aexit__"): + if isinstance(resource, AbstractAsyncContextManager): async with resource as client: yield client return - if hasattr(resource, "__enter__") and hasattr(resource, "__exit__"): + if isinstance(resource, AbstractContextManager): with resource as client: yield client return + if hasattr(resource, "__aenter__") and hasattr(resource, "__aexit__"): + async with cast(AbstractAsyncContextManager[AsyncHTTPClient], resource) as client: + yield client + return + + if hasattr(resource, "__enter__") and hasattr(resource, "__exit__"): + with cast(AbstractContextManager[AsyncHTTPClient], resource) as client: + yield client + return + + client = cast(AsyncHTTPClient, resource) try: - yield resource + yield client finally: - closer = getattr(resource, "aclose", None) - if callable(closer): - result = closer() - if inspect.isawaitable(result): - await result - return - closer = getattr(resource, "close", None) - if callable(closer): - result = closer() - if inspect.isawaitable(result): - await result + try: + await client.aclose() + except AttributeError: + closer = getattr(client, "close", None) + if callable(closer): + result = closer() + if inspect.isawaitable(result): + await result async def _emit_persistence_batch( self, aggregated: Sequence[AggregatedItem] @@ -521,7 +560,7 @@ async def _emit_persistence_batch( ) async def _enrich_movies( - self, client: Any, movies: Sequence[Any] + self, client: AsyncHTTPClient, movies: Sequence[PlexPartialObject] ) -> list[AggregatedItem]: """Fetch external metadata for *movies* and aggregate the results.""" @@ -543,7 +582,7 @@ async def _enrich_movies( ) api_key = self._tmdb_api_key - tmdb_tasks: list[asyncio.Task[Any]] = [] + tmdb_tasks: list[asyncio.Task[TMDBMovie | None]] = [] if api_key: for ids in movie_ids: if not ids.tmdb: @@ -556,11 +595,14 @@ async def _enrich_movies( imdb_map: dict[str, IMDbTitle | None] = {} retry_snapshot: set[str] = set() - tmdb_results: list[Any] = [] + tmdb_results: list[TMDBMovie | None] = [] if imdb_future is not None: combined_results = await asyncio.gather(imdb_future, *tmdb_tasks) - imdb_map = combined_results[0] - tmdb_results = list(combined_results[1:]) + imdb_map = cast(dict[str, IMDbTitle | None], combined_results[0]) + tmdb_results = [ + cast(TMDBMovie | None, result) + for result in combined_results[1:] + ] retry_snapshot = set(self._imdb_retry_queue.snapshot()) elif tmdb_tasks: tmdb_results = [await task for task in tmdb_tasks] @@ -596,7 +638,10 @@ async def _handle_sample_batch(self, batch: SampleBatch) -> None: ) async def _enrich_episodes( - self, client: Any, show: Any, episodes: Sequence[Any] + self, + client: AsyncHTTPClient, + show: PlexPartialObject, + episodes: Sequence[PlexPartialObject], ) -> list[AggregatedItem]: """Fetch external metadata for *episodes* and aggregate the results.""" @@ -665,7 +710,7 @@ async def _enrich_episodes( return aggregated async def _get_tmdb_show( - self, client: Any, tmdb_id: str + self, client: AsyncHTTPClient, tmdb_id: str ) -> TMDBShow | None: """Return the TMDb show for *tmdb_id*, using the in-memory cache.""" @@ -681,7 +726,10 @@ async def _get_tmdb_show( return show async def _bulk_lookup_tmdb_episodes( - self, client: Any, show_tmdb: TMDBShow, episodes: Sequence[Any] + self, + client: AsyncHTTPClient, + show_tmdb: TMDBShow, + episodes: Sequence[PlexPartialObject], ) -> list[TMDBEpisode | None]: """Fetch TMDb metadata for *episodes* in batches.""" @@ -834,7 +882,7 @@ async def acquire(self) -> None: async def _fetch_imdb( - client: httpx.AsyncClient, + client: AsyncHTTPClient, imdb_id: str, *, cache: IMDbCache | None, @@ -882,7 +930,7 @@ async def _fetch_imdb( async def _fetch_imdb_batch( - client: httpx.AsyncClient, + client: AsyncHTTPClient, imdb_ids: Sequence[str], *, cache: IMDbCache | None, diff --git a/mcp_plex/loader/pipeline/ingestion.py b/mcp_plex/loader/pipeline/ingestion.py index 5898029..795f560 100644 --- a/mcp_plex/loader/pipeline/ingestion.py +++ b/mcp_plex/loader/pipeline/ingestion.py @@ -23,6 +23,7 @@ enqueue_nowait, ) +from plexapi.library import LibrarySection from plexapi.server import PlexServer from plexapi.video import Episode, Movie, Season, Show @@ -171,12 +172,9 @@ async def _ingest_plex( library = plex_server.library def _log_discovered_count( - *, section: object, descriptor: str + *, section: LibrarySection, descriptor: str ) -> int | None: - try: - total = getattr(section, "totalSize") # type: ignore[assignment] - except Exception: # pragma: no cover - defensive guard - total = None + total = getattr(section, "totalSize", None) if isinstance(total, int): logger.info( "Discovered %d Plex %s(s) for ingestion.", diff --git a/mcp_plex/loader/pipeline/orchestrator.py b/mcp_plex/loader/pipeline/orchestrator.py index 38ab417..126e71a 100644 --- a/mcp_plex/loader/pipeline/orchestrator.py +++ b/mcp_plex/loader/pipeline/orchestrator.py @@ -6,7 +6,7 @@ import inspect import logging from dataclasses import dataclass -from typing import Awaitable, Callable +from typing import Awaitable, Callable, Protocol from .channels import IngestQueue, PersistenceQueue @@ -30,15 +30,30 @@ def __init__(self, spec: _StageSpec, error: BaseException) -> None: self.error = error +class IngestionStageProtocol(Protocol): + async def run(self) -> None: + ... + + +class EnrichmentStageProtocol(Protocol): + async def run(self) -> None: + ... + + +class PersistenceStageProtocol(Protocol): + async def run(self, worker_id: int) -> None: + ... + + class LoaderOrchestrator: """Run the ingestion, enrichment, and persistence stages with supervision.""" def __init__( self, *, - ingestion_stage: object, - enrichment_stage: object, - persistence_stage: object, + ingestion_stage: IngestionStageProtocol, + enrichment_stage: EnrichmentStageProtocol, + persistence_stage: PersistenceStageProtocol, ingest_queue: IngestQueue, persistence_queue: PersistenceQueue, persistence_worker_count: int = 1, @@ -162,7 +177,7 @@ async def _handle_failures(self, failures: list[_StageFailure]) -> None: # the stage-specific tests which verify cancellation side-effects. await asyncio.sleep(0) - def _drain_queue(self, queue: asyncio.Queue[object]) -> int: + def _drain_queue(self, queue: IngestQueue | PersistenceQueue) -> int: """Remove any queued items so cancellation does not leave stale work.""" drained = 0 diff --git a/mcp_plex/loader/pipeline/persistence.py b/mcp_plex/loader/pipeline/persistence.py index 3a821be..8c3519e 100644 --- a/mcp_plex/loader/pipeline/persistence.py +++ b/mcp_plex/loader/pipeline/persistence.py @@ -4,7 +4,7 @@ import asyncio import logging -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Sequence +from typing import Awaitable, Callable, Sequence, TypeAlias from .channels import ( PERSIST_DONE, @@ -14,13 +14,22 @@ ) from ...common.validation import require_positive -if TYPE_CHECKING: # pragma: no cover - typing helpers only +try: # pragma: no cover - allow import to fail when qdrant_client is absent from qdrant_client import AsyncQdrantClient, models +except ModuleNotFoundError: # pragma: no cover - tooling without qdrant installed + class AsyncQdrantClient: # type: ignore[too-few-public-methods] + """Fallback stub used when qdrant_client is unavailable.""" - PersistencePayload = list[models.PointStruct] -else: # pragma: no cover - runtime fallback when qdrant_client is absent - AsyncQdrantClient = Any # type: ignore[assignment] - PersistencePayload = list[Any] + pass + + class _ModelsStub: # type: ignore[too-few-public-methods] + class PointStruct: # type: ignore[too-few-public-methods] + ... + + models = _ModelsStub() # type: ignore[assignment] + + +PersistencePayload: TypeAlias = list["models.PointStruct"] class PersistenceStage: diff --git a/tests/test_enrichment_stage.py b/tests/test_enrichment_stage.py index cfc6d0b..11cc4ea 100644 --- a/tests/test_enrichment_stage.py +++ b/tests/test_enrichment_stage.py @@ -2,6 +2,8 @@ import logging from typing import Any +import httpx + import pytest from mcp_plex.common.types import ( @@ -27,12 +29,26 @@ ) +class _StubHTTPClient: + async def get( + self, + url: str, + *, + params: dict[str, str] | None = None, + headers: dict[str, str] | None = None, + ) -> httpx.Response: + raise AssertionError(f"Unexpected HTTP request: {url}") + + async def aclose(self) -> None: + pass + + def test_enrichment_stage_logger_name() -> None: async def scenario() -> str: ingest_queue: asyncio.Queue = asyncio.Queue() persistence_queue: asyncio.Queue = asyncio.Queue() stage = EnrichmentStage( - http_client_factory=lambda: object(), + http_client_factory=lambda: _StubHTTPClient(), tmdb_api_key="tmdb", ingest_queue=ingest_queue, persistence_queue=persistence_queue, @@ -52,7 +68,7 @@ async def scenario() -> bool: persistence_queue: asyncio.Queue = asyncio.Queue() retry_queue = IMDbRetryQueue() stage = EnrichmentStage( - http_client_factory=lambda: object(), + http_client_factory=lambda: _StubHTTPClient(), tmdb_api_key="tmdb", ingest_queue=ingest_queue, persistence_queue=persistence_queue, @@ -71,7 +87,7 @@ async def scenario() -> IMDbRetryQueue: ingest_queue: asyncio.Queue = asyncio.Queue() persistence_queue: asyncio.Queue = asyncio.Queue() stage = EnrichmentStage( - http_client_factory=lambda: object(), + http_client_factory=lambda: _StubHTTPClient(), tmdb_api_key="tmdb", ingest_queue=ingest_queue, persistence_queue=persistence_queue, @@ -125,7 +141,7 @@ async def scenario() -> list[AggregatedItem]: ingest_queue: asyncio.Queue = asyncio.Queue() persistence_queue: asyncio.Queue = asyncio.Queue() stage = EnrichmentStage( - http_client_factory=lambda: object(), + http_client_factory=lambda: _StubHTTPClient(), tmdb_api_key="token", ingest_queue=ingest_queue, persistence_queue=persistence_queue, @@ -135,7 +151,7 @@ async def scenario() -> list[AggregatedItem]: ) movie = _FakeMovie("1", imdb_id="tt0001", tmdb_id="101") return await asyncio.wait_for( - stage._enrich_movies(object(), [movie]), timeout=1 + stage._enrich_movies(_StubHTTPClient(), [movie]), timeout=1 ) result = asyncio.run(scenario()) @@ -226,8 +242,9 @@ def __init__( self.collections: list[Any] = [] -class _FakeClient: +class _FakeClient(_StubHTTPClient): def __init__(self, log: list[str]) -> None: + super().__init__() self._log = log async def __aenter__(self) -> "_FakeClient": @@ -599,7 +616,7 @@ async def scenario() -> list[list[AggregatedItem] | None]: ingest_queue: asyncio.Queue = asyncio.Queue() persistence_queue: asyncio.Queue = asyncio.Queue() stage = EnrichmentStage( - http_client_factory=lambda: object(), + http_client_factory=lambda: _StubHTTPClient(), tmdb_api_key="token", ingest_queue=ingest_queue, persistence_queue=persistence_queue, @@ -712,7 +729,7 @@ async def scenario() -> list[list[AggregatedItem] | None]: ingest_queue: asyncio.Queue = asyncio.Queue() persistence_queue: asyncio.Queue = asyncio.Queue() stage = EnrichmentStage( - http_client_factory=lambda: object(), + http_client_factory=lambda: _StubHTTPClient(), tmdb_api_key="token", ingest_queue=ingest_queue, persistence_queue=persistence_queue, @@ -771,7 +788,7 @@ async def scenario() -> tuple[list[list[AggregatedItem] | None], list[Any], list logger.propagate = False stage = EnrichmentStage( - http_client_factory=lambda: object(), + http_client_factory=lambda: _StubHTTPClient(), tmdb_api_key="", ingest_queue=ingest_queue, persistence_queue=persistence_queue, @@ -846,7 +863,7 @@ async def scenario() -> tuple[list[list[str]], int]: retry_queue = IMDbRetryQueue(["tt1", "tt2"]) stage = EnrichmentStage( - http_client_factory=lambda: object(), + http_client_factory=lambda: _StubHTTPClient(), tmdb_api_key="", ingest_queue=ingest_queue, persistence_queue=persistence_queue, @@ -899,7 +916,7 @@ async def scenario() -> tuple[list[list[AggregatedItem] | None], int, list[list[ retry_queue = IMDbRetryQueue() stage = EnrichmentStage( - http_client_factory=lambda: object(), + http_client_factory=lambda: _StubHTTPClient(), tmdb_api_key="", ingest_queue=ingest_queue, persistence_queue=persistence_queue, From a374476b4c2002e30717351ccf24296bbdc076e5 Mon Sep 17 00:00:00 2001 From: Teagan Glenn Date: Mon, 6 Oct 2025 22:16:52 -0600 Subject: [PATCH 2/2] fix(loader): rehydrate imdb cache payloads --- docker/pyproject.deps.toml | 2 +- mcp_plex/loader/imdb_cache.py | 24 ++++++++++++++++++++---- pyproject.toml | 2 +- uv.lock | 2 +- 4 files changed, 23 insertions(+), 7 deletions(-) diff --git a/docker/pyproject.deps.toml b/docker/pyproject.deps.toml index 2e96cd2..5829ba9 100644 --- a/docker/pyproject.deps.toml +++ b/docker/pyproject.deps.toml @@ -1,6 +1,6 @@ [project] name = "mcp-plex" -version = "1.0.17" +version = "1.0.18" requires-python = ">=3.11,<3.13" dependencies = [ "fastmcp>=2.11.2", diff --git a/mcp_plex/loader/imdb_cache.py b/mcp_plex/loader/imdb_cache.py index ad833b4..f06b6a4 100644 --- a/mcp_plex/loader/imdb_cache.py +++ b/mcp_plex/loader/imdb_cache.py @@ -5,6 +5,8 @@ from pathlib import Path from typing import TypeAlias, cast +from pydantic import ValidationError + from ..common.types import IMDbTitle JSONScalar: TypeAlias = str | int | float | bool | None @@ -44,10 +46,24 @@ def __init__(self, path: Path) -> None: ) else: if isinstance(loaded, dict): - self._data = { - str(key): cast(CachedIMDbPayload, value) - for key, value in loaded.items() - } + hydrated: dict[str, CachedIMDbPayload] = {} + for key, value in loaded.items(): + imdb_id = str(key) + payload: CachedIMDbPayload + if isinstance(value, dict): + try: + payload = IMDbTitle.model_validate(value) + except ValidationError as exc: + self._logger.debug( + "Failed to validate cached IMDb payload for %s; falling back to raw JSON.", + imdb_id, + exc_info=exc, + ) + payload = cast(JSONValue, value) + else: + payload = cast(JSONValue, value) + hydrated[imdb_id] = payload + self._data = hydrated else: self._logger.warning( "IMDb cache at %s did not contain an object; ignoring its contents.", diff --git a/pyproject.toml b/pyproject.toml index be4dcac..0110439 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mcp-plex" -version = "1.0.17" +version = "1.0.18" description = "Plex-Oriented Model Context Protocol Server" requires-python = ">=3.11,<3.13" diff --git a/uv.lock b/uv.lock index a988d14..d691f41 100644 --- a/uv.lock +++ b/uv.lock @@ -730,7 +730,7 @@ wheels = [ [[package]] name = "mcp-plex" -version = "1.0.17" +version = "1.0.18" source = { editable = "." } dependencies = [ { name = "fastapi" },