diff --git a/docker/pyproject.deps.toml b/docker/pyproject.deps.toml
index 64bbec1..1a12bf8 100644
--- a/docker/pyproject.deps.toml
+++ b/docker/pyproject.deps.toml
@@ -1,6 +1,6 @@
 [project]
 name = "mcp-plex"
-version = "0.26.68"
+version = "0.26.69"
 requires-python = ">=3.11,<3.13"
 dependencies = [
     "fastmcp>=2.11.2",
diff --git a/mcp_plex/loader/__init__.py b/mcp_plex/loader/__init__.py
index e037e31..1e658b4 100644
--- a/mcp_plex/loader/__init__.py
+++ b/mcp_plex/loader/__init__.py
@@ -37,6 +37,7 @@
     chunk_sequence,
     require_positive,
 )
+from .pipeline.persistence import PersistenceStage as _PersistenceStage
 from ..common.types import (
     AggregatedItem,
     ExternalIDs,
@@ -985,6 +986,18 @@ def __init__(
         self._ingest_start = now
         self._enrich_start = now
         self._upsert_start = now
+        self._persistence_stage = _PersistenceStage(
+            client=self._client,
+            collection_name=self._collection_name,
+            dense_vector_name=self._dense_model_name,
+            sparse_vector_name=self._sparse_model_name,
+            persistence_queue=self._points_queue,
+            retry_queue=self._qdrant_retry_queue,
+            upsert_semaphore=self._upsert_capacity,
+            upsert_buffer_size=self._upsert_buffer_size,
+            upsert_fn=self._perform_upsert,
+            on_batch_complete=self._handle_upsert_batch,
+        )
 
     @property
     def qdrant_retry_queue(self) -> asyncio.Queue[list[models.PointStruct]]:
@@ -1009,7 +1022,7 @@ async def execute(self) -> None:
             for worker_id in range(self._enrichment_workers)
         ]
         upsert_tasks = [
-            asyncio.create_task(self._upsert_worker(worker_id))
+            asyncio.create_task(self._persistence_stage.run(worker_id))
             for worker_id in range(self._max_concurrent_upserts)
         ]
         error: BaseException | None = None
@@ -1297,48 +1310,30 @@ async def _emit_points(self, aggregated: Sequence[AggregatedItem]) -> None:
             build_point(item, self._dense_model_name, self._sparse_model_name)
             for item in aggregated
         ]
-        for chunk in _chunk_sequence(points, self._upsert_buffer_size):
-            batch = list(chunk)
-            if not batch:
-                continue
-            await self._upsert_capacity.acquire()
-            try:
-                await self._points_queue.put(batch)
-            except BaseException:
-                self._upsert_capacity.release()
-                raise
+        await self._persistence_stage.enqueue_points(points)
 
-    async def _upsert_worker(self, worker_id: int) -> None:
-        while True:
-            batch = await self._points_queue.get()
-            if batch is None:
-                self._points_queue.task_done()
-                break
-            logger.info(
-                "Upsert worker %d handling %d points (queue size=%d)",
-                worker_id,
-                len(batch),
-                self._points_queue.qsize(),
-            )
-            try:
-                if self._upserted_points == 0:
-                    self._upsert_start = time.perf_counter()
-                await _upsert_in_batches(
-                    self._client,
-                    self._collection_name,
-                    batch,
-                    retry_queue=self._qdrant_retry_queue,
-                )
-                self._upserted_points += len(batch)
-                self._log_progress(
-                    f"Upsert worker {worker_id}",
-                    self._upserted_points,
-                    self._upsert_start,
-                    self._points_queue.qsize(),
-                )
-            finally:
-                self._points_queue.task_done()
-                self._upsert_capacity.release()
+    async def _perform_upsert(
+        self, batch: Sequence[models.PointStruct]
+    ) -> None:
+        await _upsert_in_batches(
+            self._client,
+            self._collection_name,
+            list(batch),
+            retry_queue=self._qdrant_retry_queue,
+        )
+
+    def _handle_upsert_batch(
+        self, worker_id: int, batch_size: int, queue_size: int
+    ) -> None:
+        if self._upserted_points == 0:
+            self._upsert_start = time.perf_counter()
+        self._upserted_points += batch_size
+        self._log_progress(
+            f"Upsert worker {worker_id}",
+            self._upserted_points,
+            self._upsert_start,
+            queue_size,
+        )
 
 async def run(
     plex_url: Optional[str],
     plex_token: Optional[str],
diff --git a/mcp_plex/loader/pipeline/persistence.py b/mcp_plex/loader/pipeline/persistence.py
index 539708f..4fc380b 100644
--- a/mcp_plex/loader/pipeline/persistence.py
+++ b/mcp_plex/loader/pipeline/persistence.py
@@ -4,9 +4,9 @@
 
 import asyncio
 import logging
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Sequence
 
-from .channels import PersistenceQueue
+from .channels import PersistenceQueue, chunk_sequence, require_positive
 
 if TYPE_CHECKING:  # pragma: no cover - typing helpers only
     from qdrant_client import AsyncQdrantClient, models
@@ -30,6 +30,9 @@ def __init__(
         persistence_queue: PersistenceQueue,
         retry_queue: asyncio.Queue[PersistencePayload],
         upsert_semaphore: asyncio.Semaphore,
+        upsert_buffer_size: int,
+        upsert_fn: Callable[[PersistencePayload], Awaitable[None]],
+        on_batch_complete: Callable[[int, int, int], None] | None = None,
     ) -> None:
         self._client = client
         self._collection_name = str(collection_name)
@@ -38,6 +41,11 @@ def __init__(
         self._persistence_queue = persistence_queue
         self._retry_queue = retry_queue
         self._upsert_semaphore = upsert_semaphore
+        self._upsert_buffer_size = require_positive(
+            upsert_buffer_size, name="upsert_buffer_size"
+        )
+        self._upsert_fn = upsert_fn
+        self._on_batch_complete = on_batch_complete
         self._logger = logging.getLogger("mcp_plex.loader.persistence")
 
     @property
@@ -88,7 +96,32 @@ def upsert_semaphore(self) -> asyncio.Semaphore:
 
         return self._upsert_semaphore
 
-    async def run(self) -> None:
+    @property
+    def upsert_buffer_size(self) -> int:
+        """Maximum number of points per persistence batch."""
+
+        return self._upsert_buffer_size
+
+    async def enqueue_points(
+        self, points: Sequence["models.PointStruct"]
+    ) -> None:
+        """Chunk *points* and place them on the persistence queue."""
+
+        if not points:
+            return
+
+        for chunk in chunk_sequence(list(points), self._upsert_buffer_size):
+            batch = list(chunk)
+            if not batch:
+                continue
+            await self._upsert_semaphore.acquire()
+            try:
+                await self._persistence_queue.put(batch)
+            except BaseException:
+                self._upsert_semaphore.release()
+                raise
+
+    async def run(self, worker_id: int) -> None:
         """Drain the persistence queue until a sentinel is received."""
 
         while True:
@@ -96,15 +129,24 @@ async def run(self) -> None:
         while True:
             payload = await self._persistence_queue.get()
             try:
                 if payload is None:
                     self._logger.debug(
-                        "Persistence queue sentinel received; finishing placeholder run."
+                        "Persistence queue sentinel received; finishing run for worker %d.",
+                        worker_id,
                     )
                     return
 
-                self._logger.debug(
-                    "Placeholder persistence stage received batch with %d items.",
+                queue_size = self._persistence_queue.qsize()
+                self._logger.info(
+                    "Upsert worker %d handling %d points (queue size=%d)",
+                    worker_id,
                     len(payload),
+                    queue_size,
                 )
+                await self._upsert_fn(payload)
+                if self._on_batch_complete is not None:
+                    self._on_batch_complete(
+                        worker_id, len(payload), self._persistence_queue.qsize()
+                    )
             finally:
                 self._persistence_queue.task_done()
-
-            await asyncio.sleep(0)
+                if payload is not None:
+                    self._upsert_semaphore.release()
diff --git a/pyproject.toml b/pyproject.toml
index 67c9bf6..94133ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "mcp-plex"
-version = "0.26.68"
+version = "0.26.69"
 description = "Plex-Oriented Model Context Protocol Server"
 requires-python = ">=3.11,<3.13"
diff --git a/tests/test_persistence_stage.py b/tests/test_persistence_stage.py
index d7e0c54..a552c43 100644
--- a/tests/test_persistence_stage.py
+++ b/tests/test_persistence_stage.py
@@ -1,11 +1,20 @@
 import asyncio
+from typing import Any
 
+import pytest
+
+from mcp_plex.loader import _upsert_in_batches
 from mcp_plex.loader.pipeline.channels import PersistenceQueue
 from mcp_plex.loader.pipeline.persistence import PersistenceStage
 
 
 class _FakeQdrantClient:
-    pass
+    async def upsert(self, *, collection_name: str, points: list[Any]) -> None:
+        raise NotImplementedError
+
+
+async def _noop_upsert(_: list[Any]) -> None:
+    await asyncio.sleep(0)
 
 
 def test_persistence_stage_logger_name() -> None:
@@ -23,6 +32,9 @@ async def scenario() -> str:
             persistence_queue=persistence_queue,
             retry_queue=retry_queue,
             upsert_semaphore=semaphore,
+            upsert_buffer_size=2,
+            upsert_fn=_noop_upsert,
+            on_batch_complete=None,
         )
         return stage.logger.name
 
@@ -32,7 +44,13 @@ def test_persistence_stage_holds_dependencies() -> None:
-    async def scenario() -> tuple[PersistenceStage, _FakeQdrantClient, PersistenceQueue, asyncio.Queue, asyncio.Semaphore]:
+    async def scenario() -> tuple[
+        PersistenceStage,
+        _FakeQdrantClient,
+        PersistenceQueue,
+        asyncio.Queue,
+        asyncio.Semaphore,
+    ]:
         client = _FakeQdrantClient()
         persistence_queue: PersistenceQueue = asyncio.Queue()
         retry_queue: asyncio.Queue = asyncio.Queue()
         semaphore = asyncio.Semaphore(
@@ -46,6 +64,9 @@ async def scenario() -> tuple[PersistenceStage, _FakeQdrantClu
             persistence_queue=persistence_queue,
             retry_queue=retry_queue,
             upsert_semaphore=semaphore,
+            upsert_buffer_size=3,
+            upsert_fn=_noop_upsert,
+            on_batch_complete=None,
         )
         return stage, client, persistence_queue, retry_queue, semaphore
 
@@ -59,3 +80,141 @@ async def scenario() -> tuple[PersistenceStage, _FakeQdrantClient, PersistenceQu
     assert stage.persistence_queue is persistence_queue
     assert stage.retry_queue is retry_queue
     assert stage.upsert_semaphore is semaphore
+
+
+def test_persistence_stage_upserts_batches() -> None:
+    async def scenario() -> tuple[list[list[int]], list[tuple[int, int, int]], int]:
+        persistence_queue: PersistenceQueue = asyncio.Queue()
+        retry_queue: asyncio.Queue = asyncio.Queue()
+        semaphore = asyncio.Semaphore(2)
+
+        processed: list[list[int]] = []
+
+        async def fake_upsert(batch: list[int]) -> None:
+            processed.append(list(batch))
+
+        completions: list[tuple[int, int, int]] = []
+
+        def on_batch_complete(worker_id: int, batch_size: int, queue_size: int) -> None:
+            completions.append((worker_id, batch_size, queue_size))
+
+        stage = PersistenceStage(
+            client=_FakeQdrantClient(),
+            collection_name="media-items",
+            dense_vector_name="dense",
+            sparse_vector_name="sparse",
+            persistence_queue=persistence_queue,
+            retry_queue=retry_queue,
+            upsert_semaphore=semaphore,
+            upsert_buffer_size=2,
+            upsert_fn=fake_upsert,
+            on_batch_complete=on_batch_complete,
+        )
+
+        workers = [asyncio.create_task(stage.run(worker_id)) for worker_id in range(2)]
+
+        await stage.enqueue_points([1, 2, 3])
+        await persistence_queue.join()
+
+        for _ in workers:
+            await persistence_queue.put(None)
+
+        await asyncio.gather(*workers)
+
+        return processed, completions, semaphore._value  # type: ignore[attr-defined]
+
+    processed, completions, semaphore_value = asyncio.run(scenario())
+
+    assert {tuple(batch) for batch in processed} == {(1, 2), (3,)}
+    assert sorted(batch_size for _, batch_size, _ in completions) == [1, 2]
+    assert completions[-1][2] == 0
+    assert semaphore_value == 2
+
+
+def test_persistence_stage_populates_retry_queue_on_failure() -> None:
+    async def scenario() -> list[list[int]]:
+        persistence_queue: PersistenceQueue = asyncio.Queue()
+        retry_queue: asyncio.Queue[list[list[int]]] = asyncio.Queue()
+        semaphore = asyncio.Semaphore(1)
+
+        class _FailingClient:
+            async def upsert(
+                self, *, collection_name: str, points: list[list[int]]
+            ) -> None:
+                raise RuntimeError("boom")
+
+        async def upsert_fn(batch: list[list[int]]) -> None:
+            await _upsert_in_batches(
+                _FailingClient(),
+                "media-items",
+                batch,
+                retry_queue=retry_queue,
+            )
+
+        stage = PersistenceStage(
+            client=_FakeQdrantClient(),
+            collection_name="media-items",
+            dense_vector_name="dense",
+            sparse_vector_name="sparse",
+            persistence_queue=persistence_queue,
+            retry_queue=retry_queue,
+            upsert_semaphore=semaphore,
+            upsert_buffer_size=10,
+            upsert_fn=upsert_fn,
+            on_batch_complete=None,
+        )
+
+        worker = asyncio.create_task(stage.run(0))
+
+        await stage.enqueue_points([[1, 2]])
+        await persistence_queue.join()
+        await persistence_queue.put(None)
+        await asyncio.gather(worker)
+
+        failures: list[list[int]] = []
+        while not retry_queue.empty():
+            failures.append(await retry_queue.get())
+
+        return failures
+
+    failures = asyncio.run(scenario())
+
+    assert failures == [[[1, 2]]]
+
+
+def test_persistence_stage_releases_semaphore_on_upsert_error() -> None:
+    async def scenario() -> int:
+        persistence_queue: PersistenceQueue = asyncio.Queue()
+        retry_queue: asyncio.Queue = asyncio.Queue()
+        semaphore = asyncio.Semaphore(1)
+
+        async def failing_upsert(batch: list[int]) -> None:
+            raise RuntimeError("boom")
+
+        stage = PersistenceStage(
+            client=_FakeQdrantClient(),
+            collection_name="media-items",
+            dense_vector_name="dense",
+            sparse_vector_name="sparse",
+            persistence_queue=persistence_queue,
+            retry_queue=retry_queue,
+            upsert_semaphore=semaphore,
+            upsert_buffer_size=5,
+            upsert_fn=failing_upsert,
+            on_batch_complete=None,
+        )
+
+        worker = asyncio.create_task(stage.run(0))
+
+        await stage.enqueue_points([1])
+
+        with pytest.raises(RuntimeError):
+            await worker
+
+        await asyncio.wait_for(persistence_queue.join(), timeout=1)
+
+        return semaphore._value  # type: ignore[attr-defined]
+
+    semaphore_value = asyncio.run(scenario())
+
+    assert semaphore_value == 1
diff --git a/uv.lock b/uv.lock
index c667106..8f1fd6e 100644
--- a/uv.lock
+++ b/uv.lock
@@ -730,7 +730,7 @@ wheels = [
 
 [[package]]
 name = "mcp-plex"
-version = "0.26.68"
+version = "0.26.69"
 source = { editable = "." }
 dependencies = [
     { name = "fastapi" },