Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 70 additions & 36 deletions mcp_plex/loader/pipeline/enrichment.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,31 +527,44 @@ async def _enrich_movies(

movie_ids = [_extract_external_ids(movie) for movie in movies]
imdb_ids = [ids.imdb for ids in movie_ids if ids.imdb]
retry_snapshot: set[str] = set()
imdb_map = {}
imdb_future: asyncio.Task[dict[str, IMDbTitle | None]] | None = None
if imdb_ids:
imdb_map = await _fetch_imdb_batch(
client,
imdb_ids,
cache=self._imdb_cache,
throttle=self._imdb_throttle,
max_retries=self._imdb_max_retries,
backoff=self._imdb_backoff,
retry_queue=self._imdb_retry_queue,
batch_limit=self._imdb_batch_limit,
imdb_future = asyncio.create_task(
_fetch_imdb_batch(
client,
imdb_ids,
cache=self._imdb_cache,
throttle=self._imdb_throttle,
max_retries=self._imdb_max_retries,
backoff=self._imdb_backoff,
retry_queue=self._imdb_retry_queue,
batch_limit=self._imdb_batch_limit,
)
)
retry_snapshot = set(self._imdb_retry_queue.snapshot())

tmdb_results: list[Any] = []
api_key = self._tmdb_api_key
tmdb_tasks: list[asyncio.Task[Any]] = []
if api_key:
tmdb_tasks = [
_fetch_tmdb_movie(client, ids.tmdb, api_key)
for ids in movie_ids
if ids.tmdb
]
if tmdb_tasks:
tmdb_results = await asyncio.gather(*tmdb_tasks)
for ids in movie_ids:
if not ids.tmdb:
continue
tmdb_tasks.append(
asyncio.create_task(
_fetch_tmdb_movie(client, ids.tmdb, api_key)
)
)

imdb_map: dict[str, IMDbTitle | None] = {}
retry_snapshot: set[str] = set()
tmdb_results: list[Any] = []
if imdb_future is not None:
combined_results = await asyncio.gather(imdb_future, *tmdb_tasks)
imdb_map = combined_results[0]
tmdb_results = list(combined_results[1:])
retry_snapshot = set(self._imdb_retry_queue.snapshot())
elif tmdb_tasks:
tmdb_results = [await task for task in tmdb_tasks]

tmdb_iter = iter(tmdb_results)

aggregated: list[AggregatedItem] = []
Expand Down Expand Up @@ -588,31 +601,52 @@ async def _enrich_episodes(
"""Fetch external metadata for *episodes* and aggregate the results."""

show_ids = _extract_external_ids(show)
imdb_future: asyncio.Task[dict[str, IMDbTitle | None]] | None = None
show_future: asyncio.Task[TMDBShow | None] | None = None
show_tmdb: TMDBShow | None = None
if show_ids.tmdb:
show_tmdb = await self._get_tmdb_show(client, show_ids.tmdb)
show_future = asyncio.create_task(
self._get_tmdb_show(client, show_ids.tmdb)
)
episode_ids = [_extract_external_ids(ep) for ep in episodes]
imdb_ids = [ids.imdb for ids in episode_ids if ids.imdb]
retry_snapshot: set[str] = set()
imdb_map = {}
if imdb_ids:
imdb_map = await _fetch_imdb_batch(
client,
imdb_ids,
cache=self._imdb_cache,
throttle=self._imdb_throttle,
max_retries=self._imdb_max_retries,
backoff=self._imdb_backoff,
retry_queue=self._imdb_retry_queue,
batch_limit=self._imdb_batch_limit,
imdb_future = asyncio.create_task(
_fetch_imdb_batch(
client,
imdb_ids,
cache=self._imdb_cache,
throttle=self._imdb_throttle,
max_retries=self._imdb_max_retries,
backoff=self._imdb_backoff,
retry_queue=self._imdb_retry_queue,
batch_limit=self._imdb_batch_limit,
)
)
retry_snapshot = set(self._imdb_retry_queue.snapshot())

tmdb_results: list[TMDBEpisode | None] = [None] * len(episodes)
if show_future is not None:
show_tmdb = await show_future

tmdb_future: asyncio.Task[list[TMDBEpisode | None]] | None = None
if show_tmdb:
tmdb_results = await self._bulk_lookup_tmdb_episodes(
client, show_tmdb, episodes
tmdb_future = asyncio.create_task(
self._bulk_lookup_tmdb_episodes(client, show_tmdb, episodes)
)

imdb_map: dict[str, IMDbTitle | None] = {}
retry_snapshot: set[str] = set()
tmdb_results: list[TMDBEpisode | None] = [None] * len(episodes)
if imdb_future and tmdb_future:
imdb_map, tmdb_results = await asyncio.gather(
imdb_future, tmdb_future
)
retry_snapshot = set(self._imdb_retry_queue.snapshot())
elif imdb_future:
imdb_map = await imdb_future
retry_snapshot = set(self._imdb_retry_queue.snapshot())
elif tmdb_future:
tmdb_results = await tmdb_future

tmdb_iter = iter(tmdb_results)

aggregated: list[AggregatedItem] = []
Expand Down
61 changes: 61 additions & 0 deletions tests/test_enrichment_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import logging
from typing import Any

import pytest

from mcp_plex.common.types import (
AggregatedItem,
IMDbTitle,
Expand Down Expand Up @@ -83,6 +85,65 @@ async def scenario() -> IMDbRetryQueue:
assert isinstance(retry_queue, IMDbRetryQueue)


def test_enrich_movies_runs_tmdb_and_imdb_requests_in_parallel(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that ``_enrich_movies`` overlaps its IMDb and TMDb lookups.

    Both fetchers are replaced with stubs. Each stub flags its own start and
    then waits (at most one second) for the other stub to flag its start, so
    a stub can only return once both lookups have been started.
    """
    # Populated inside ``scenario``, i.e. once the loop started by
    # ``asyncio.run`` is already running.
    events: dict[str, asyncio.Event] = {}

    async def rendezvous(own: str, peer: str) -> None:
        # Announce that this side has started, then block until the peer has.
        events[own].set()
        await asyncio.wait_for(events[peer].wait(), timeout=1)

    async def imdb_stub(client, imdb_ids, **kwargs):
        await rendezvous("imdb_started", "tmdb_started")
        titles = {}
        for imdb_id in imdb_ids:
            titles[imdb_id] = IMDbTitle(
                id=imdb_id,
                type="movie",
                primaryTitle=f"IMDb {imdb_id}",
            )
        return titles

    async def tmdb_stub(client, tmdb_id, api_key):
        await rendezvous("tmdb_started", "imdb_started")
        payload = {
            "id": int(tmdb_id),
            "title": f"TMDb {tmdb_id}",
        }
        return TMDBMovie.model_validate(payload)

    # Swap the module-level fetchers for the stubs above.
    replacements = (
        ("mcp_plex.loader.pipeline.enrichment._fetch_imdb_batch", imdb_stub),
        ("mcp_plex.loader.pipeline.enrichment._fetch_tmdb_movie", tmdb_stub),
    )
    for target, stub in replacements:
        monkeypatch.setattr(target, stub)

    async def scenario() -> list[AggregatedItem]:
        events["imdb_started"] = asyncio.Event()
        events["tmdb_started"] = asyncio.Event()
        stage = EnrichmentStage(
            http_client_factory=lambda: object(),
            tmdb_api_key="token",
            ingest_queue=asyncio.Queue(),
            persistence_queue=asyncio.Queue(),
            imdb_retry_queue=IMDbRetryQueue(),
            movie_batch_size=5,
            episode_batch_size=5,
        )
        movies = [_FakeMovie("1", imdb_id="tt0001", tmdb_id="101")]
        pending = stage._enrich_movies(object(), movies)
        # Outer guard: bound the whole enrichment call to one second.
        return await asyncio.wait_for(pending, timeout=1)

    result = asyncio.run(scenario())

    first = result[0]
    assert first.imdb is not None
    assert first.tmdb is not None


class _FakeGuid:
def __init__(self, guid: str) -> None:
self.id = guid
Expand Down