Skip to content

Commit c70f559

Browse files
authored
test(loader): cover batch gather helper (#42)
1 parent 0dc0fc1 commit c70f559

File tree

5 files changed

+62
-11
lines changed

5 files changed

+62
-11
lines changed

mcp_plex/loader.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
import sys
88
from pathlib import Path
9-
from typing import List, Optional
9+
from typing import Awaitable, List, Optional, Sequence, TypeVar
1010

1111
import click
1212
import httpx
@@ -36,6 +36,20 @@
3636

3737
logger = logging.getLogger(__name__)
3838

39+
T = TypeVar("T")
40+
41+
42+
async def _gather_in_batches(
43+
tasks: Sequence[Awaitable[T]], batch_size: int
44+
) -> List[T]:
45+
"""Gather awaitable tasks in fixed-size batches."""
46+
47+
results: List[T] = []
48+
for i in range(0, len(tasks), batch_size):
49+
batch = tasks[i : i + batch_size]
50+
results.extend(await asyncio.gather(*batch))
51+
return results
52+
3953

4054
async def _fetch_imdb(client: httpx.AsyncClient, imdb_id: str) -> Optional[IMDbTitle]:
4155
"""Fetch metadata for an IMDb ID."""
@@ -137,7 +151,9 @@ def _build_plex_item(item: PlexPartialObject) -> PlexItem:
137151
)
138152

139153

140-
async def _load_from_plex(server: PlexServer, tmdb_api_key: str) -> List[AggregatedItem]:
154+
async def _load_from_plex(
155+
server: PlexServer, tmdb_api_key: str, *, batch_size: int = 50
156+
) -> List[AggregatedItem]:
141157
"""Load items from a live Plex server."""
142158

143159
async def _augment_movie(client: httpx.AsyncClient, movie: PlexPartialObject) -> AggregatedItem:
@@ -174,11 +190,9 @@ async def _augment_episode(
174190
results: List[AggregatedItem] = []
175191
async with httpx.AsyncClient(timeout=30) as client:
176192
movie_section = server.library.section("Movies")
177-
movie_tasks = [
178-
_augment_movie(client, movie) for movie in movie_section.all()
179-
]
193+
movie_tasks = [_augment_movie(client, movie) for movie in movie_section.all()]
180194
if movie_tasks:
181-
results.extend(await asyncio.gather(*movie_tasks))
195+
results.extend(await _gather_in_batches(movie_tasks, batch_size))
182196

183197
show_section = server.library.section("TV Shows")
184198
for show in show_section.all():
@@ -187,10 +201,11 @@ async def _augment_episode(
187201
if show_ids.tmdb:
188202
show_tmdb = await _fetch_tmdb_show(client, show_ids.tmdb, tmdb_api_key)
189203
episode_tasks = [
190-
_augment_episode(client, episode, show_tmdb) for episode in show.episodes()
204+
_augment_episode(client, episode, show_tmdb)
205+
for episode in show.episodes()
191206
]
192207
if episode_tasks:
193-
results.extend(await asyncio.gather(*episode_tasks))
208+
results.extend(await _gather_in_batches(episode_tasks, batch_size))
194209
return results
195210

196211

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "mcp-plex"
7-
version = "0.26.7"
7+
version = "0.26.8"
88

99
description = "Plex-Oriented Model Context Protocol Server"
1010
requires-python = ">=3.11,<3.13"

tests/test_gather_in_batches.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import asyncio
2+
3+
from mcp_plex import loader
4+
5+
6+
async def _echo(value: int) -> int:
7+
await asyncio.sleep(0)
8+
return value
9+
10+
11+
def test_gather_in_batches(monkeypatch):
12+
calls: list[int] = []
13+
orig_gather = asyncio.gather
14+
15+
async def fake_gather(*coros):
16+
calls.append(len(coros))
17+
return await orig_gather(*coros)
18+
19+
monkeypatch.setattr(asyncio, "gather", fake_gather)
20+
21+
tasks = [_echo(i) for i in range(5)]
22+
results = asyncio.run(loader._gather_in_batches(tasks, 2))
23+
24+
assert results == list(range(5))
25+
assert calls == [2, 2, 1]
26+

tests/test_load_from_plex.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,17 @@ async def handler(request):
9494
lambda *args, **kwargs: orig_client(transport=transport),
9595
)
9696

97-
items = asyncio.run(loader._load_from_plex(server, "key"))
97+
calls = []
98+
orig_batch = loader._gather_in_batches
99+
100+
async def fake_batch(tasks, batch_size):
101+
calls.append((len(tasks), batch_size))
102+
return await orig_batch(tasks, batch_size)
103+
104+
monkeypatch.setattr(loader, "_gather_in_batches", fake_batch)
105+
106+
items = asyncio.run(loader._load_from_plex(server, "key", batch_size=1))
107+
assert calls == [(1, 1), (2, 1)]
98108
assert len(items) == 3
99109
assert items[0].tmdb and items[0].tmdb.id == 27205
100110
assert items[1].tmdb and items[1].tmdb.id == 62085

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)