From 716bcea4f77f11c903714013a9d526b0eaa334ee Mon Sep 17 00:00:00 2001 From: Teagan Glenn Date: Sat, 13 Sep 2025 22:14:17 -0600 Subject: [PATCH] test(loader): cover batch gather helper --- mcp_plex/loader.py | 31 +++++++++++++++++++++++-------- pyproject.toml | 2 +- tests/test_gather_in_batches.py | 26 ++++++++++++++++++++++++++ tests/test_load_from_plex.py | 12 +++++++++++- uv.lock | 2 +- 5 files changed, 62 insertions(+), 11 deletions(-) create mode 100644 tests/test_gather_in_batches.py diff --git a/mcp_plex/loader.py b/mcp_plex/loader.py index 66017df..e8f3c44 100644 --- a/mcp_plex/loader.py +++ b/mcp_plex/loader.py @@ -6,7 +6,7 @@ import logging import sys from pathlib import Path -from typing import List, Optional +from typing import Awaitable, List, Optional, Sequence, TypeVar import click import httpx @@ -36,6 +36,20 @@ logger = logging.getLogger(__name__) +T = TypeVar("T") + + +async def _gather_in_batches( + tasks: Sequence[Awaitable[T]], batch_size: int +) -> List[T]: + """Gather awaitable tasks in fixed-size batches.""" + + results: List[T] = [] + for i in range(0, len(tasks), batch_size): + batch = tasks[i : i + batch_size] + results.extend(await asyncio.gather(*batch)) + return results + async def _fetch_imdb(client: httpx.AsyncClient, imdb_id: str) -> Optional[IMDbTitle]: """Fetch metadata for an IMDb ID.""" @@ -137,7 +151,9 @@ def _build_plex_item(item: PlexPartialObject) -> PlexItem: ) -async def _load_from_plex(server: PlexServer, tmdb_api_key: str) -> List[AggregatedItem]: +async def _load_from_plex( + server: PlexServer, tmdb_api_key: str, *, batch_size: int = 50 +) -> List[AggregatedItem]: """Load items from a live Plex server.""" async def _augment_movie(client: httpx.AsyncClient, movie: PlexPartialObject) -> AggregatedItem: @@ -174,11 +190,9 @@ async def _augment_episode( results: List[AggregatedItem] = [] async with httpx.AsyncClient(timeout=30) as client: movie_section = server.library.section("Movies") - movie_tasks = [ - _augment_movie(client, movie) for movie in movie_section.all() - ] + movie_tasks = [_augment_movie(client, movie) for movie in movie_section.all()] if movie_tasks: - results.extend(await asyncio.gather(*movie_tasks)) + results.extend(await _gather_in_batches(movie_tasks, batch_size)) show_section = server.library.section("TV Shows") for show in show_section.all(): @@ -187,10 +201,11 @@ async def _augment_episode( if show_ids.tmdb: show_tmdb = await _fetch_tmdb_show(client, show_ids.tmdb, tmdb_api_key) episode_tasks = [ - _augment_episode(client, episode, show_tmdb) for episode in show.episodes() + _augment_episode(client, episode, show_tmdb) + for episode in show.episodes() ] if episode_tasks: - results.extend(await asyncio.gather(*episode_tasks)) + results.extend(await _gather_in_batches(episode_tasks, batch_size)) return results diff --git a/pyproject.toml b/pyproject.toml index 7ec2fe1..502f27e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mcp-plex" -version = "0.26.7" +version = "0.26.8" description = "Plex-Oriented Model Context Protocol Server" requires-python = ">=3.11,<3.13" diff --git a/tests/test_gather_in_batches.py b/tests/test_gather_in_batches.py new file mode 100644 index 0000000..8863b08 --- /dev/null +++ b/tests/test_gather_in_batches.py @@ -0,0 +1,26 @@ +import asyncio + +from mcp_plex import loader + + +async def _echo(value: int) -> int: + await asyncio.sleep(0) + return value + + +def test_gather_in_batches(monkeypatch): + calls: list[int] = [] + orig_gather = asyncio.gather + + async def fake_gather(*coros): + calls.append(len(coros)) + return await orig_gather(*coros) + + monkeypatch.setattr(asyncio, "gather", fake_gather) + + tasks = [_echo(i) for i in range(5)] + results = asyncio.run(loader._gather_in_batches(tasks, 2)) + + assert results == list(range(5)) + assert calls == [2, 2, 1] + diff --git a/tests/test_load_from_plex.py b/tests/test_load_from_plex.py index d7b757b..10b26e4 100644 --- a/tests/test_load_from_plex.py +++ b/tests/test_load_from_plex.py @@ -94,7 +94,17 @@ async def handler(request): lambda *args, **kwargs: orig_client(transport=transport), ) - items = asyncio.run(loader._load_from_plex(server, "key")) + calls = [] + orig_batch = loader._gather_in_batches + + async def fake_batch(tasks, batch_size): + calls.append((len(tasks), batch_size)) + return await orig_batch(tasks, batch_size) + + monkeypatch.setattr(loader, "_gather_in_batches", fake_batch) + + items = asyncio.run(loader._load_from_plex(server, "key", batch_size=1)) + assert calls == [(1, 1), (2, 1)] assert len(items) == 3 assert items[0].tmdb and items[0].tmdb.id == 27205 assert items[1].tmdb and items[1].tmdb.id == 62085 diff --git a/uv.lock b/uv.lock index 1f50dc9..c49501c 100644 --- a/uv.lock +++ b/uv.lock @@ -690,7 +690,7 @@ wheels = [ [[package]] name = "mcp-plex" -version = "0.26.7" +version = "0.26.8" source = { editable = "." } dependencies = [ { name = "fastapi" },