diff --git a/docker/pyproject.deps.toml b/docker/pyproject.deps.toml index fffdb74..679322a 100644 --- a/docker/pyproject.deps.toml +++ b/docker/pyproject.deps.toml @@ -1,6 +1,6 @@ [project] name = "mcp-plex" -version = "0.26.53" +version = "0.26.57" requires-python = ">=3.11,<3.13" dependencies = [ "fastmcp>=2.11.2", diff --git a/mcp_plex/loader/__init__.py b/mcp_plex/loader/__init__.py index 0fb3d6a..7aeb7b6 100644 --- a/mcp_plex/loader/__init__.py +++ b/mcp_plex/loader/__init__.py @@ -9,6 +9,7 @@ import warnings from collections import deque from pathlib import Path +from types import TracebackType from typing import AsyncIterator, Awaitable, Iterable, List, Optional, Sequence, TypeVar import click @@ -909,35 +910,51 @@ async def run( item_iter = _iter_from_plex(server, tmdb_api_key) points_buffer: List[models.PointStruct] = [] - upsert_tasks: set[asyncio.Task[None]] = set() qdrant_retry_queue: asyncio.Queue[list[models.PointStruct]] = asyncio.Queue() max_concurrent_upserts = _require_positive( _qdrant_max_concurrent_upserts, name="max_concurrent_upserts" ) - - async def _schedule_upsert(batch: List[models.PointStruct]) -> None: - logger.info( - "Upserting %d points into Qdrant collection %s in batches of %d", - len(batch), - collection_name, - _qdrant_batch_size, - ) - task = asyncio.create_task( - _upsert_in_batches( - client, + upsert_queue: asyncio.Queue[List[models.PointStruct] | None] = asyncio.Queue() + upsert_capacity = asyncio.Semaphore(max_concurrent_upserts) + batches_enqueued = 0 + worker_error: Exception | None = None + worker_error_tb: TracebackType | None = None + worker_error_lock = asyncio.Lock() + + async def _upsert_worker() -> None: + nonlocal worker_error, worker_error_tb + while True: + batch = await upsert_queue.get() + if batch is None: + upsert_queue.task_done() + break + logger.info( + "Upserting %d points into Qdrant collection %s using batches of up to %d", + len(batch), collection_name, - batch, - retry_queue=qdrant_retry_queue, + _qdrant_batch_size, ) - ) - upsert_tasks.add(task) - if len(upsert_tasks) >= max_concurrent_upserts: - done, _ = await asyncio.wait( - upsert_tasks, return_when=asyncio.FIRST_COMPLETED - ) - for finished in done: - upsert_tasks.discard(finished) - finished.result() + try: + await _upsert_in_batches( + client, + collection_name, + batch, + retry_queue=qdrant_retry_queue, + ) + except Exception as exc: # defensive guard + async with worker_error_lock: + if worker_error is None: + worker_error = exc + worker_error_tb = exc.__traceback__ + logger.exception("Unexpected error upserting batch") + finally: + upsert_queue.task_done() + if batch is not None: + upsert_capacity.release() + + upsert_workers = [ + asyncio.create_task(_upsert_worker()) for _ in range(max_concurrent_upserts) + ] async for item in item_iter: items.append(item) @@ -1048,19 +1065,40 @@ async def _schedule_upsert(batch: List[models.PointStruct]) -> None: if len(points_buffer) >= upsert_buffer_size: batch = list(points_buffer) points_buffer.clear() - await _schedule_upsert(batch) + batches_enqueued += 1 + await upsert_capacity.acquire() + try: + await upsert_queue.put(batch) + except BaseException: + upsert_capacity.release() + raise logger.info("Loaded %d items", len(items)) if points_buffer: batch = list(points_buffer) points_buffer.clear() - await _schedule_upsert(batch) + batches_enqueued += 1 + await upsert_capacity.acquire() + try: + await upsert_queue.put(batch) + except BaseException: + upsert_capacity.release() + raise - if upsert_tasks: - await asyncio.gather(*upsert_tasks) - else: - logger.info("No points to upsert") + try: + await upsert_queue.join() + if batches_enqueued == 0: + logger.info("No points to upsert") + finally: + for _ in range(max_concurrent_upserts): + await upsert_queue.put(None) + await asyncio.gather(*upsert_workers) + + if worker_error is not None: + if worker_error_tb is not None: + raise worker_error.with_traceback(worker_error_tb) + raise worker_error await _process_qdrant_retry_queue(client, collection_name, qdrant_retry_queue) diff --git a/pyproject.toml b/pyproject.toml index 713a39b..4f54423 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mcp-plex" -version = "0.26.53" +version = "0.26.57" description = "Plex-Oriented Model Context Protocol Server" requires-python = ">=3.11,<3.13" diff --git a/tests/test_loader_logging.py b/tests/test_loader_logging.py index 8ee35bb..53de66e 100644 --- a/tests/test_loader_logging.py +++ b/tests/test_loader_logging.py @@ -35,6 +35,7 @@ def test_run_logs_upsert(monkeypatch, caplog): asyncio.run(loader.run(None, None, None, sample_dir, None, None)) assert "Loaded 2 items" in caplog.text assert "Upserting 2 points" in caplog.text + assert "using batches of up to" in caplog.text def test_run_logs_no_points(monkeypatch, caplog): @@ -74,6 +75,15 @@ def test_run_limits_concurrent_upserts(monkeypatch): concurrency = {"current": 0, "max": 0} started = asyncio.Queue() release_queue = asyncio.Queue() + third_requested = asyncio.Event() + + base_items = list(loader._load_from_sample(sample_dir)) + + async def fake_iter(sample_dir): + for idx, item in enumerate(base_items + base_items[:1]): + if idx == 2: + third_requested.set() + yield item async def fake_upsert(client, collection_name, points, **kwargs): concurrency["current"] += 1 @@ -83,20 +93,25 @@ async def fake_upsert(client, collection_name, points, **kwargs): concurrency["current"] -= 1 monkeypatch.setattr(loader, "_upsert_in_batches", fake_upsert) + monkeypatch.setattr(loader, "_iter_from_sample", fake_iter) async def invoke(): run_task = asyncio.create_task( loader.run(None, None, None, sample_dir, None, None, upsert_buffer_size=1) ) - await started.get() - release_queue.put_nowait(None) - await started.get() - release_queue.put_nowait(None) + await asyncio.wait_for(started.get(), timeout=1) + assert not third_requested.is_set() + await release_queue.put(None) + await asyncio.wait_for(started.get(), timeout=1) + await release_queue.put(None) + await asyncio.wait_for(started.get(), timeout=1) + await release_queue.put(None) await run_task asyncio.run(invoke()) assert concurrency["max"] == 1 + assert third_requested.is_set() def test_run_ensures_collection_before_loading(monkeypatch): diff --git a/uv.lock b/uv.lock index 11c138b..a909efc 100644 --- a/uv.lock +++ b/uv.lock @@ -730,7 +730,7 @@ wheels = [ [[package]] name = "mcp-plex" -version = "0.26.53" +version = "0.26.57" source = { editable = "." } dependencies = [ { name = "fastapi" },