diff --git a/docker/pyproject.deps.toml b/docker/pyproject.deps.toml
index c0eca25..abb03e5 100644
--- a/docker/pyproject.deps.toml
+++ b/docker/pyproject.deps.toml
@@ -1,6 +1,6 @@
 [project]
 name = "mcp-plex"
-version = "0.26.70"
+version = "0.26.71"
 requires-python = ">=3.11,<3.13"
 dependencies = [
     "fastmcp>=2.11.2",
diff --git a/mcp_plex/loader/pipeline/orchestrator.py b/mcp_plex/loader/pipeline/orchestrator.py
index 45363e7..8a51753 100644
--- a/mcp_plex/loader/pipeline/orchestrator.py
+++ b/mcp_plex/loader/pipeline/orchestrator.py
@@ -1 +1,177 @@
-"""Placeholder module for the loader pipeline."""
+"""Coordinating logic tying the loader pipeline stages together."""
+
+from __future__ import annotations
+
+import asyncio
+import inspect
+import logging
+from dataclasses import dataclass
+from typing import Awaitable, Callable
+
+from .channels import IngestQueue, PersistenceQueue
+
+LOGGER = logging.getLogger("mcp_plex.loader.orchestrator")
+
+
+@dataclass(frozen=True, slots=True)
+class _StageSpec:
+    """Descriptor for a running pipeline stage."""
+
+    role: str
+    worker_id: int | None = None
+
+
+class _StageFailure(Exception):
+    """Wrapper exception capturing the originating stage failure."""
+
+    def __init__(self, spec: _StageSpec, error: BaseException) -> None:
+        super().__init__(str(error))
+        self.spec = spec
+        self.error = error
+
+
+class LoaderOrchestrator:
+    """Run the ingestion, enrichment, and persistence stages with supervision."""
+
+    def __init__(
+        self,
+        *,
+        ingestion_stage: object,
+        enrichment_stage: object,
+        persistence_stage: object,
+        ingest_queue: IngestQueue,
+        persistence_queue: PersistenceQueue,
+        persistence_worker_count: int = 1,
+        logger: logging.Logger | None = None,
+    ) -> None:
+        if persistence_worker_count <= 0:
+            raise ValueError("persistence_worker_count must be positive")
+
+        self._ingestion_stage = ingestion_stage
+        self._enrichment_stage = enrichment_stage
+        self._persistence_stage = persistence_stage
+        self._ingest_queue = ingest_queue
+        self._persistence_queue = persistence_queue
+        self._persistence_worker_count = int(persistence_worker_count)
+        self._logger = logger or LOGGER
+
+    async def run(self) -> None:
+        """Execute the configured pipeline stages concurrently."""
+
+        try:
+            async with asyncio.TaskGroup() as group:
+                group.create_task(
+                    self._run_stage(
+                        _StageSpec(role="ingestion"),
+                        getattr(self._ingestion_stage, "run"),
+                    )
+                )
+                group.create_task(
+                    self._run_stage(
+                        _StageSpec(role="enrichment"),
+                        getattr(self._enrichment_stage, "run"),
+                    )
+                )
+                persistence_runner = getattr(self._persistence_stage, "run")
+                for worker_id in range(self._persistence_worker_count):
+                    group.create_task(
+                        self._run_stage(
+                            _StageSpec(role="persistence", worker_id=worker_id),
+                            persistence_runner,
+                            worker_id,
+                        )
+                    )
+        except* _StageFailure as exc_group:
+            failures = list(exc_group.exceptions)
+            await self._handle_failures(failures)
+            # Re-raise the first underlying error after cleanup so callers see the
+            # original exception rather than the wrapper.
+            raise failures[0].error
+
+    async def _run_stage(
+        self,
+        spec: _StageSpec,
+        runner: Callable[..., Awaitable[object] | object],
+        *args: object,
+    ) -> None:
+        """Execute *runner* and wrap unexpected exceptions with stage metadata."""
+
+        try:
+            result = runner(*args)
+            if inspect.isawaitable(result):
+                await result
+        except asyncio.CancelledError:
+            raise
+        except BaseException as exc:
+            raise _StageFailure(spec, exc) from exc
+
+    async def _handle_failures(self, failures: list[_StageFailure]) -> None:
+        """Log stage-specific failures and drain queues during cancellation."""
+
+        if not failures:
+            return
+
+        roles = {failure.spec.role for failure in failures}
+        if "ingestion" in roles:
+            self._logger.warning(
+                "Ingestion stage failed; cancelling enrichment and persistence tasks."
+            )
+        else:
+            self._logger.warning(
+                "Downstream stage failed; cancelling ingestion and related tasks."
+            )
+
+        for failure in failures:
+            stage_name = self._describe_stage(failure.spec)
+            self._logger.error(
+                "%s failed: %s",
+                stage_name,
+                failure.error,
+                exc_info=failure.error,
+            )
+
+        drained_ingest = self._drain_queue(self._ingest_queue)
+        drained_persist = self._drain_queue(self._persistence_queue)
+        if drained_ingest:
+            self._logger.debug(
+                "Drained %d item(s) from the ingest queue during cancellation.",
+                drained_ingest,
+            )
+        if drained_persist:
+            self._logger.debug(
+                "Drained %d item(s) from the persistence queue during cancellation.",
+                drained_persist,
+            )
+
+        # Yield to the event loop so cancelled tasks can finish cleanup before the
+        # caller observes the exception. This mirrors the behaviour expected by
+        # the stage-specific tests which verify cancellation side-effects.
+        await asyncio.sleep(0)
+
+    def _drain_queue(self, queue: asyncio.Queue[object]) -> int:
+        """Remove any queued items so cancellation does not leave stale work."""
+
+        drained = 0
+        while True:
+            try:
+                queue.get_nowait()
+            except asyncio.QueueEmpty:
+                break
+            else:
+                drained += 1
+            try:
+                queue.task_done()
+            except ValueError:  # Queue.join() not in use; ignore bookkeeping.
+                pass
+        return drained
+
+    def _describe_stage(self, spec: _StageSpec) -> str:
+        """Return a human-friendly name for *spec*."""
+
+        role = spec.role.capitalize()
+        if spec.worker_id is None:
+            return f"{role} stage"
+        return f"{role} stage (worker {spec.worker_id})"
+
+
+__all__ = ["LoaderOrchestrator"]
diff --git a/pyproject.toml b/pyproject.toml
index 41b4c70..8b9f01e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "mcp-plex"
-version = "0.26.70"
+version = "0.26.71"
 description = "Plex-Oriented Model Context Protocol Server"
 requires-python = ">=3.11,<3.13"
 
diff --git a/tests/test_loader_orchestrator.py b/tests/test_loader_orchestrator.py
new file mode 100644
index 0000000..42e24a4
--- /dev/null
+++ b/tests/test_loader_orchestrator.py
@@ -0,0 +1,166 @@
+import asyncio
+import logging
+
+import pytest
+
+from mcp_plex.loader.pipeline.orchestrator import LoaderOrchestrator
+
+
+class FailingIngestionStage:
+    def __init__(self, queue: asyncio.Queue[object]) -> None:
+        self.queue = queue
+
+    async def run(self) -> None:
+        await self.queue.put("batch-1")
+        raise RuntimeError("ingestion boom")
+
+
+class BlockingEnrichmentStage:
+    def __init__(self) -> None:
+        self.cancelled = asyncio.Event()
+        self._blocker = asyncio.Event()
+
+    async def run(self) -> None:
+        try:
+            await self._blocker.wait()
+        except asyncio.CancelledError:
+            self.cancelled.set()
+            raise
+
+
+class BlockingPersistenceStage:
+    def __init__(self, queue: asyncio.Queue[object]) -> None:
+        self.queue = queue
+        self.cancelled = asyncio.Event()
+
+    async def run(self, worker_id: int) -> None:
+        try:
+            while True:
+                await self.queue.get()
+        except asyncio.CancelledError:
+            self.cancelled.set()
+            raise
+
+
+class BlockingIngestionStage:
+    def __init__(self, queue: asyncio.Queue[object]) -> None:
+        self.queue = queue
+        self.cancelled = asyncio.Event()
+        self._blocker = asyncio.Event()
+
+    async def run(self) -> None:
+        try:
+            await self.queue.put("batch-1")
+            await self.queue.put("batch-2")
+            await self._blocker.wait()
+        except asyncio.CancelledError:
+            self.cancelled.set()
+            raise
+
+
+class SingleBatchEnrichmentStage:
+    def __init__(
+        self,
+        ingest_queue: asyncio.Queue[object],
+        persistence_queue: asyncio.Queue[object],
+    ) -> None:
+        self.ingest_queue = ingest_queue
+        self.persistence_queue = persistence_queue
+        self.cancelled = asyncio.Event()
+        self._blocker = asyncio.Event()
+
+    async def run(self) -> None:
+        try:
+            payload = await self.ingest_queue.get()
+            self.ingest_queue.task_done()
+            await self.persistence_queue.put(payload)
+            await self._blocker.wait()
+        except asyncio.CancelledError:
+            self.cancelled.set()
+            raise
+
+
+class FailingPersistenceStage:
+    def __init__(self, queue: asyncio.Queue[object]) -> None:
+        self.queue = queue
+
+    async def run(self, worker_id: int) -> None:
+        payload = await self.queue.get()
+        self.queue.task_done()
+        raise RuntimeError(f"persistence boom {worker_id}: {payload}")
+
+
+def _build_orchestrator(
+    *,
+    ingestion_stage: object,
+    enrichment_stage: object,
+    persistence_stage: object,
+    ingest_queue: asyncio.Queue[object],
+    persistence_queue: asyncio.Queue[object],
+) -> LoaderOrchestrator:
+    return LoaderOrchestrator(
+        ingestion_stage=ingestion_stage,
+        enrichment_stage=enrichment_stage,
+        persistence_stage=persistence_stage,
+        ingest_queue=ingest_queue,
+        persistence_queue=persistence_queue,
+        persistence_worker_count=1,
+    )
+
+
+def test_ingestion_failure_cancels_downstream(caplog: pytest.LogCaptureFixture) -> None:
+    ingest_queue: asyncio.Queue[object] = asyncio.Queue()
+    persistence_queue: asyncio.Queue[object] = asyncio.Queue()
+    ingestion_stage = FailingIngestionStage(ingest_queue)
+    enrichment_stage = BlockingEnrichmentStage()
+    persistence_stage = BlockingPersistenceStage(persistence_queue)
+    orchestrator = _build_orchestrator(
+        ingestion_stage=ingestion_stage,
+        enrichment_stage=enrichment_stage,
+        persistence_stage=persistence_stage,
+        ingest_queue=ingest_queue,
+        persistence_queue=persistence_queue,
+    )
+
+    async def _run() -> None:
+        await orchestrator.run()
+
+    with caplog.at_level(logging.ERROR, logger="mcp_plex.loader.orchestrator"):
+        with pytest.raises(RuntimeError, match="ingestion boom"):
+            asyncio.run(_run())
+
+    assert enrichment_stage.cancelled.is_set()
+    assert persistence_stage.cancelled.is_set()
+    assert ingest_queue.qsize() == 0
+    assert persistence_queue.qsize() == 0
+    error_messages = [record.getMessage() for record in caplog.records]
+    assert any("Ingestion stage failed" in message for message in error_messages)
+
+
+def test_persistence_failure_cancels_upstream(caplog: pytest.LogCaptureFixture) -> None:
+    ingest_queue: asyncio.Queue[object] = asyncio.Queue()
+    persistence_queue: asyncio.Queue[object] = asyncio.Queue()
+    ingestion_stage = BlockingIngestionStage(ingest_queue)
+    enrichment_stage = SingleBatchEnrichmentStage(ingest_queue, persistence_queue)
+    persistence_stage = FailingPersistenceStage(persistence_queue)
+    orchestrator = _build_orchestrator(
+        ingestion_stage=ingestion_stage,
+        enrichment_stage=enrichment_stage,
+        persistence_stage=persistence_stage,
+        ingest_queue=ingest_queue,
+        persistence_queue=persistence_queue,
+    )
+
+    async def _run() -> None:
+        await orchestrator.run()
+
+    with caplog.at_level(logging.ERROR, logger="mcp_plex.loader.orchestrator"):
+        with pytest.raises(RuntimeError, match="persistence boom"):
+            asyncio.run(_run())
+
+    assert ingestion_stage.cancelled.is_set()
+    assert enrichment_stage.cancelled.is_set()
+    assert ingest_queue.qsize() == 0
+    assert persistence_queue.qsize() == 0
+    error_messages = [record.getMessage() for record in caplog.records]
+    assert any("Persistence stage" in message for message in error_messages)
diff --git a/uv.lock b/uv.lock
index 23b6113..9d0386a 100644
--- a/uv.lock
+++ b/uv.lock
@@ -730,7 +730,7 @@ wheels = [
 
 [[package]]
 name = "mcp-plex"
-version = "0.26.70"
+version = "0.26.71"
 source = { editable = "." }
 dependencies = [
     { name = "fastapi" },
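
Note for reviewers: the sketch below shows one way to drive the new LoaderOrchestrator end to end. It is a minimal, hypothetical harness, not part of this patch: the Demo* stage classes and the None end-of-input sentinel are invented for illustration, and plain asyncio.Queue instances are passed where the constructor annotates IngestQueue/PersistenceQueue, mirroring what the tests above already do. The only contract assumed is the duck-typed one the orchestrator itself exercises: ingestion and enrichment expose run(), persistence exposes run(worker_id).

import asyncio

from mcp_plex.loader.pipeline.orchestrator import LoaderOrchestrator

# Hypothetical end-of-input marker; nothing in this patch defines one.
_SENTINEL: object = None


class DemoIngestion:
    """Toy ingestion stage: emits two batches, then the sentinel, then exits."""

    def __init__(self, queue: asyncio.Queue[object]) -> None:
        self.queue = queue

    async def run(self) -> None:
        for batch in ("batch-1", "batch-2", _SENTINEL):
            await self.queue.put(batch)


class DemoEnrichment:
    """Toy enrichment stage: forwards batches until the sentinel arrives."""

    def __init__(
        self,
        ingest_queue: asyncio.Queue[object],
        persistence_queue: asyncio.Queue[object],
    ) -> None:
        self.ingest_queue = ingest_queue
        self.persistence_queue = persistence_queue

    async def run(self) -> None:
        while True:
            payload = await self.ingest_queue.get()
            await self.persistence_queue.put(payload)
            if payload is _SENTINEL:
                return


class DemoPersistence:
    """Toy persistence stage: consumes batches until the sentinel arrives."""

    def __init__(self, queue: asyncio.Queue[object]) -> None:
        self.queue = queue

    async def run(self, worker_id: int) -> None:
        while True:
            payload = await self.queue.get()
            if payload is _SENTINEL:
                return
            print(f"worker {worker_id} persisted {payload!r}")


async def main() -> None:
    ingest_queue: asyncio.Queue[object] = asyncio.Queue()
    persistence_queue: asyncio.Queue[object] = asyncio.Queue()
    orchestrator = LoaderOrchestrator(
        ingestion_stage=DemoIngestion(ingest_queue),
        enrichment_stage=DemoEnrichment(ingest_queue, persistence_queue),
        persistence_stage=DemoPersistence(persistence_queue),
        ingest_queue=ingest_queue,
        persistence_queue=persistence_queue,
        persistence_worker_count=1,
    )
    await orchestrator.run()  # re-raises the original stage error on failure


if __name__ == "__main__":
    asyncio.run(main())

With persistence_worker_count > 1, a harness like this would need one sentinel per worker, since each worker exits only when it sees one. On any stage failure, run() cancels the surviving stages, drains both queues, and re-raises the original exception rather than the _StageFailure wrapper, which is exactly what the two tests above assert.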