diff --git a/mcp_plex/loader/pipeline/enrichment.py b/mcp_plex/loader/pipeline/enrichment.py index deb0de5..01043db 100644 --- a/mcp_plex/loader/pipeline/enrichment.py +++ b/mcp_plex/loader/pipeline/enrichment.py @@ -527,31 +527,44 @@ async def _enrich_movies( movie_ids = [_extract_external_ids(movie) for movie in movies] imdb_ids = [ids.imdb for ids in movie_ids if ids.imdb] - retry_snapshot: set[str] = set() - imdb_map = {} + imdb_future: asyncio.Task[dict[str, IMDbTitle | None]] | None = None if imdb_ids: - imdb_map = await _fetch_imdb_batch( - client, - imdb_ids, - cache=self._imdb_cache, - throttle=self._imdb_throttle, - max_retries=self._imdb_max_retries, - backoff=self._imdb_backoff, - retry_queue=self._imdb_retry_queue, - batch_limit=self._imdb_batch_limit, + imdb_future = asyncio.create_task( + _fetch_imdb_batch( + client, + imdb_ids, + cache=self._imdb_cache, + throttle=self._imdb_throttle, + max_retries=self._imdb_max_retries, + backoff=self._imdb_backoff, + retry_queue=self._imdb_retry_queue, + batch_limit=self._imdb_batch_limit, + ) ) - retry_snapshot = set(self._imdb_retry_queue.snapshot()) - tmdb_results: list[Any] = [] api_key = self._tmdb_api_key + tmdb_tasks: list[asyncio.Task[Any]] = [] if api_key: - tmdb_tasks = [ - _fetch_tmdb_movie(client, ids.tmdb, api_key) - for ids in movie_ids - if ids.tmdb - ] - if tmdb_tasks: - tmdb_results = await asyncio.gather(*tmdb_tasks) + for ids in movie_ids: + if not ids.tmdb: + continue + tmdb_tasks.append( + asyncio.create_task( + _fetch_tmdb_movie(client, ids.tmdb, api_key) + ) + ) + + imdb_map: dict[str, IMDbTitle | None] = {} + retry_snapshot: set[str] = set() + tmdb_results: list[Any] = [] + if imdb_future is not None: + combined_results = await asyncio.gather(imdb_future, *tmdb_tasks) + imdb_map = combined_results[0] + tmdb_results = list(combined_results[1:]) + retry_snapshot = set(self._imdb_retry_queue.snapshot()) + elif tmdb_tasks: + tmdb_results = [await task for task in tmdb_tasks] + tmdb_iter = iter(tmdb_results) aggregated: list[AggregatedItem] = [] @@ -588,31 +601,52 @@ async def _enrich_episodes( """Fetch external metadata for *episodes* and aggregate the results.""" show_ids = _extract_external_ids(show) + imdb_future: asyncio.Task[dict[str, IMDbTitle | None]] | None = None + show_future: asyncio.Task[TMDBShow | None] | None = None show_tmdb: TMDBShow | None = None if show_ids.tmdb: - show_tmdb = await self._get_tmdb_show(client, show_ids.tmdb) + show_future = asyncio.create_task( + self._get_tmdb_show(client, show_ids.tmdb) + ) episode_ids = [_extract_external_ids(ep) for ep in episodes] imdb_ids = [ids.imdb for ids in episode_ids if ids.imdb] - retry_snapshot: set[str] = set() - imdb_map = {} if imdb_ids: - imdb_map = await _fetch_imdb_batch( - client, - imdb_ids, - cache=self._imdb_cache, - throttle=self._imdb_throttle, - max_retries=self._imdb_max_retries, - backoff=self._imdb_backoff, - retry_queue=self._imdb_retry_queue, - batch_limit=self._imdb_batch_limit, + imdb_future = asyncio.create_task( + _fetch_imdb_batch( + client, + imdb_ids, + cache=self._imdb_cache, + throttle=self._imdb_throttle, + max_retries=self._imdb_max_retries, + backoff=self._imdb_backoff, + retry_queue=self._imdb_retry_queue, + batch_limit=self._imdb_batch_limit, + ) ) - retry_snapshot = set(self._imdb_retry_queue.snapshot()) - tmdb_results: list[TMDBEpisode | None] = [None] * len(episodes) + if show_future is not None: + show_tmdb = await show_future + + tmdb_future: asyncio.Task[list[TMDBEpisode | None]] | None = None if show_tmdb: - tmdb_results = await self._bulk_lookup_tmdb_episodes( - client, show_tmdb, episodes + tmdb_future = asyncio.create_task( + self._bulk_lookup_tmdb_episodes(client, show_tmdb, episodes) + ) + + imdb_map: dict[str, IMDbTitle | None] = {} + retry_snapshot: set[str] = set() + tmdb_results: list[TMDBEpisode | None] = [None] * len(episodes) + if imdb_future and tmdb_future: + imdb_map, tmdb_results = await asyncio.gather( + imdb_future, tmdb_future ) + retry_snapshot = set(self._imdb_retry_queue.snapshot()) + elif imdb_future: + imdb_map = await imdb_future + retry_snapshot = set(self._imdb_retry_queue.snapshot()) + elif tmdb_future: + tmdb_results = await tmdb_future + tmdb_iter = iter(tmdb_results) aggregated: list[AggregatedItem] = [] diff --git a/tests/test_enrichment_stage.py b/tests/test_enrichment_stage.py index 9e4dabc..cfc6d0b 100644 --- a/tests/test_enrichment_stage.py +++ b/tests/test_enrichment_stage.py @@ -2,6 +2,8 @@ import logging from typing import Any +import pytest + from mcp_plex.common.types import ( AggregatedItem, IMDbTitle, @@ -83,6 +85,65 @@ async def scenario() -> IMDbRetryQueue: assert isinstance(retry_queue, IMDbRetryQueue) +def test_enrich_movies_runs_tmdb_and_imdb_requests_in_parallel( + monkeypatch: pytest.MonkeyPatch, +) -> None: + events: dict[str, asyncio.Event] = {} + + async def fake_fetch_imdb_batch(client, imdb_ids, **kwargs): + events["imdb_started"].set() + await asyncio.wait_for(events["tmdb_started"].wait(), timeout=1) + return { + imdb_id: IMDbTitle( + id=imdb_id, + type="movie", + primaryTitle=f"IMDb {imdb_id}", + ) + for imdb_id in imdb_ids + } + + async def fake_fetch_tmdb_movie(client, tmdb_id, api_key): + events["tmdb_started"].set() + await asyncio.wait_for(events["imdb_started"].wait(), timeout=1) + return TMDBMovie.model_validate({ + "id": int(tmdb_id), + "title": f"TMDb {tmdb_id}", + }) + + monkeypatch.setattr( + "mcp_plex.loader.pipeline.enrichment._fetch_imdb_batch", + fake_fetch_imdb_batch, + ) + monkeypatch.setattr( + "mcp_plex.loader.pipeline.enrichment._fetch_tmdb_movie", + fake_fetch_tmdb_movie, + ) + + async def scenario() -> list[AggregatedItem]: + events["imdb_started"] = asyncio.Event() + events["tmdb_started"] = asyncio.Event() + ingest_queue: asyncio.Queue = asyncio.Queue() + persistence_queue: asyncio.Queue = asyncio.Queue() + stage = EnrichmentStage( + http_client_factory=lambda: object(), + tmdb_api_key="token", + ingest_queue=ingest_queue, + persistence_queue=persistence_queue, + imdb_retry_queue=IMDbRetryQueue(), + movie_batch_size=5, + episode_batch_size=5, + ) + movie = _FakeMovie("1", imdb_id="tt0001", tmdb_id="101") + return await asyncio.wait_for( + stage._enrich_movies(object(), [movie]), timeout=1 + ) + + result = asyncio.run(scenario()) + + assert result[0].imdb is not None + assert result[0].tmdb is not None + + class _FakeGuid: def __init__(self, guid: str) -> None: self.id = guid