2 changes: 1 addition & 1 deletion docker/pyproject.deps.toml
@@ -1,6 +1,6 @@
[project]
name = "mcp-plex"
-version = "0.26.70"
+version = "0.26.71"
requires-python = ">=3.11,<3.13"
dependencies = [
"fastmcp>=2.11.2",
178 changes: 177 additions & 1 deletion mcp_plex/loader/pipeline/orchestrator.py
@@ -1 +1,177 @@
-"""Placeholder module for the loader pipeline."""
"""Coordinating logic tying the loader pipeline stages together."""

from __future__ import annotations

import asyncio
import inspect
import logging
from dataclasses import dataclass
from typing import Awaitable, Callable

from .channels import IngestQueue, PersistenceQueue

LOGGER = logging.getLogger("mcp_plex.loader.orchestrator")


@dataclass(frozen=True, slots=True)
class _StageSpec:
    """Descriptor for a running pipeline stage."""

    role: str
    worker_id: int | None = None


class _StageFailure(Exception):
    """Wrapper exception capturing the originating stage failure."""

    def __init__(self, spec: _StageSpec, error: BaseException) -> None:
        super().__init__(str(error))
        self.spec = spec
        self.error = error


class LoaderOrchestrator:
    """Run the ingestion, enrichment, and persistence stages with supervision."""

    def __init__(
        self,
        *,
        ingestion_stage: object,
        enrichment_stage: object,
        persistence_stage: object,
        ingest_queue: IngestQueue,
        persistence_queue: PersistenceQueue,
        persistence_worker_count: int = 1,
        logger: logging.Logger | None = None,
    ) -> None:
        if persistence_worker_count <= 0:
            raise ValueError("persistence_worker_count must be positive")

        self._ingestion_stage = ingestion_stage
        self._enrichment_stage = enrichment_stage
        self._persistence_stage = persistence_stage
        self._ingest_queue = ingest_queue
        self._persistence_queue = persistence_queue
        self._persistence_worker_count = int(persistence_worker_count)
        self._logger = logger or LOGGER

    async def run(self) -> None:
        """Execute the configured pipeline stages concurrently."""

        try:
            async with asyncio.TaskGroup() as group:
                group.create_task(
                    self._run_stage(
                        _StageSpec(role="ingestion"),
                        getattr(self._ingestion_stage, "run"),
                    )
                )
                group.create_task(
                    self._run_stage(
                        _StageSpec(role="enrichment"),
                        getattr(self._enrichment_stage, "run"),
                    )
                )
                persistence_runner = getattr(self._persistence_stage, "run")
                for worker_id in range(self._persistence_worker_count):
                    group.create_task(
                        self._run_stage(
                            _StageSpec(role="persistence", worker_id=worker_id),
                            persistence_runner,
                            worker_id,
                        )
                    )
        except* _StageFailure as exc_group:
            failures = list(exc_group.exceptions)
            await self._handle_failures(failures)
            # Re-raise the first underlying error after cleanup so callers see the
            # original exception rather than the wrapper.
            raise failures[0].error

    async def _run_stage(
        self,
        spec: _StageSpec,
        runner: Callable[..., Awaitable[object] | object],
        *args: object,
    ) -> None:
        """Execute *runner* and wrap unexpected exceptions with stage metadata."""

        try:
            result = runner(*args)
            if inspect.isawaitable(result):
                await result
        except asyncio.CancelledError:
            raise
        except BaseException as exc:
            raise _StageFailure(spec, exc) from exc

    async def _handle_failures(self, failures: list[_StageFailure]) -> None:
        """Log stage-specific failures and drain queues during cancellation."""

        if not failures:
            return

        roles = {failure.spec.role for failure in failures}
        if "ingestion" in roles:
            self._logger.warning(
                "Ingestion stage failed; cancelling enrichment and persistence tasks."
            )
        else:
            self._logger.warning(
                "Downstream stage failed; cancelling ingestion and related tasks."
            )

        for failure in failures:
            stage_name = self._describe_stage(failure.spec)
            self._logger.error(
                "%s failed: %s",
                stage_name,
                failure.error,
                exc_info=failure.error,
            )

        drained_ingest = self._drain_queue(self._ingest_queue)
        drained_persist = self._drain_queue(self._persistence_queue)
        if drained_ingest:
            self._logger.debug(
                "Drained %d item(s) from the ingest queue during cancellation.",
                drained_ingest,
            )
        if drained_persist:
            self._logger.debug(
                "Drained %d item(s) from the persistence queue during cancellation.",
                drained_persist,
            )

        # Yield to the event loop so cancelled tasks can finish cleanup before the
        # caller observes the exception. This mirrors the behaviour expected by
        # the stage-specific tests which verify cancellation side-effects.
        await asyncio.sleep(0)

    def _drain_queue(self, queue: asyncio.Queue[object]) -> int:
        """Remove any queued items so cancellation does not leave stale work."""

        drained = 0
        while True:
            try:
                queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            else:
                drained += 1
            try:
                queue.task_done()
            except ValueError:  # Queue.join() not in use; ignore bookkeeping.
                pass
        return drained

    def _describe_stage(self, spec: _StageSpec) -> str:
        """Return a human-friendly name for *spec*."""

        role = spec.role.capitalize()
        if spec.worker_id is None:
            return f"{role} stage"
        return f"{role} stage (worker {spec.worker_id})"


__all__ = ["LoaderOrchestrator"]
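
For orientation, a minimal usage sketch of the new orchestrator, not taken from this change: it assumes plain asyncio.Queue instances satisfy the IngestQueue and PersistenceQueue aliases (as the tests below suggest), and the Echo* stage classes are hypothetical stand-ins whose run signatures match what LoaderOrchestrator calls. Each stage only needs a run coroutine; persistence workers additionally receive their worker_id.

import asyncio

from mcp_plex.loader.pipeline.orchestrator import LoaderOrchestrator


class EchoIngestion:
    """Hypothetical ingestion stage: emits a single batch, then finishes."""

    def __init__(self, ingest_queue: asyncio.Queue[object]) -> None:
        self._queue = ingest_queue

    async def run(self) -> None:
        await self._queue.put("batch-1")


class EchoEnrichment:
    """Hypothetical enrichment stage: forwards one batch downstream."""

    def __init__(
        self,
        ingest_queue: asyncio.Queue[object],
        persistence_queue: asyncio.Queue[object],
    ) -> None:
        self._ingest = ingest_queue
        self._persist = persistence_queue

    async def run(self) -> None:
        payload = await self._ingest.get()
        self._ingest.task_done()
        await self._persist.put(payload)


class EchoPersistence:
    """Hypothetical persistence stage: consumes one batch, then finishes."""

    def __init__(self, persistence_queue: asyncio.Queue[object]) -> None:
        self._queue = persistence_queue

    async def run(self, worker_id: int) -> None:
        payload = await self._queue.get()
        self._queue.task_done()
        print(f"worker {worker_id} persisted {payload!r}")


async def main() -> None:
    ingest_queue: asyncio.Queue[object] = asyncio.Queue()
    persistence_queue: asyncio.Queue[object] = asyncio.Queue()
    orchestrator = LoaderOrchestrator(
        ingestion_stage=EchoIngestion(ingest_queue),
        enrichment_stage=EchoEnrichment(ingest_queue, persistence_queue),
        persistence_stage=EchoPersistence(persistence_queue),
        ingest_queue=ingest_queue,
        persistence_queue=persistence_queue,
        persistence_worker_count=1,
    )
    # If any stage raised, run() would log the failure, cancel the sibling
    # tasks, drain both queues, and re-raise the original exception.
    await orchestrator.run()


if __name__ == "__main__":
    asyncio.run(main())
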
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "mcp-plex"
-version = "0.26.70"
+version = "0.26.71"

description = "Plex-Oriented Model Context Protocol Server"
requires-python = ">=3.11,<3.13"
166 changes: 166 additions & 0 deletions tests/test_loader_orchestrator.py
@@ -0,0 +1,166 @@
import asyncio
import logging

import pytest

from mcp_plex.loader.pipeline.orchestrator import LoaderOrchestrator


class FailingIngestionStage:
    def __init__(self, queue: asyncio.Queue[object]) -> None:
        self.queue = queue

    async def run(self) -> None:
        await self.queue.put("batch-1")
        raise RuntimeError("ingestion boom")


class BlockingEnrichmentStage:
    def __init__(self) -> None:
        self.cancelled = asyncio.Event()
        self._blocker = asyncio.Event()

    async def run(self) -> None:
        try:
            await self._blocker.wait()
        except asyncio.CancelledError:
            self.cancelled.set()
            raise


class BlockingPersistenceStage:
    def __init__(self, queue: asyncio.Queue[object]) -> None:
        self.queue = queue
        self.cancelled = asyncio.Event()

    async def run(self, worker_id: int) -> None:
        try:
            while True:
                await self.queue.get()
        except asyncio.CancelledError:
            self.cancelled.set()
            raise


class BlockingIngestionStage:
    def __init__(self, queue: asyncio.Queue[object]) -> None:
        self.queue = queue
        self.cancelled = asyncio.Event()
        self._blocker = asyncio.Event()

    async def run(self) -> None:
        try:
            await self.queue.put("batch-1")
            await self.queue.put("batch-2")
            await self._blocker.wait()
        except asyncio.CancelledError:
            self.cancelled.set()
            raise


class SingleBatchEnrichmentStage:
    def __init__(
        self,
        ingest_queue: asyncio.Queue[object],
        persistence_queue: asyncio.Queue[object],
    ) -> None:
        self.ingest_queue = ingest_queue
        self.persistence_queue = persistence_queue
        self.cancelled = asyncio.Event()
        self._blocker = asyncio.Event()

    async def run(self) -> None:
        try:
            payload = await self.ingest_queue.get()
            self.ingest_queue.task_done()
            await self.persistence_queue.put(payload)
            await self._blocker.wait()
        except asyncio.CancelledError:
            self.cancelled.set()
            raise


class FailingPersistenceStage:
    def __init__(self, queue: asyncio.Queue[object]) -> None:
        self.queue = queue

    async def run(self, worker_id: int) -> None:
        payload = await self.queue.get()
        self.queue.task_done()
        raise RuntimeError(f"persistence boom {worker_id}: {payload}")


def _build_orchestrator(
    *,
    ingestion_stage: object,
    enrichment_stage: object,
    persistence_stage: object,
    ingest_queue: asyncio.Queue[object],
    persistence_queue: asyncio.Queue[object],
) -> LoaderOrchestrator:
    return LoaderOrchestrator(
        ingestion_stage=ingestion_stage,
        enrichment_stage=enrichment_stage,
        persistence_stage=persistence_stage,
        ingest_queue=ingest_queue,
        persistence_queue=persistence_queue,
        persistence_worker_count=1,
    )


def test_ingestion_failure_cancels_downstream(caplog: pytest.LogCaptureFixture) -> None:
    ingest_queue: asyncio.Queue[object] = asyncio.Queue()
    persistence_queue: asyncio.Queue[object] = asyncio.Queue()
    ingestion_stage = FailingIngestionStage(ingest_queue)
    enrichment_stage = BlockingEnrichmentStage()
    persistence_stage = BlockingPersistenceStage(persistence_queue)
    orchestrator = _build_orchestrator(
        ingestion_stage=ingestion_stage,
        enrichment_stage=enrichment_stage,
        persistence_stage=persistence_stage,
        ingest_queue=ingest_queue,
        persistence_queue=persistence_queue,
    )

    async def _run() -> None:
        await orchestrator.run()

    with caplog.at_level(logging.ERROR, logger="mcp_plex.loader.orchestrator"):
        with pytest.raises(RuntimeError, match="ingestion boom"):
            asyncio.run(_run())

    assert enrichment_stage.cancelled.is_set()
    assert persistence_stage.cancelled.is_set()
    assert ingest_queue.qsize() == 0
    assert persistence_queue.qsize() == 0
    error_messages = [record.getMessage() for record in caplog.records]
    assert any("Ingestion stage failed" in message for message in error_messages)


def test_persistence_failure_cancels_upstream(caplog: pytest.LogCaptureFixture) -> None:
    ingest_queue: asyncio.Queue[object] = asyncio.Queue()
    persistence_queue: asyncio.Queue[object] = asyncio.Queue()
    ingestion_stage = BlockingIngestionStage(ingest_queue)
    enrichment_stage = SingleBatchEnrichmentStage(ingest_queue, persistence_queue)
    persistence_stage = FailingPersistenceStage(persistence_queue)
    orchestrator = _build_orchestrator(
        ingestion_stage=ingestion_stage,
        enrichment_stage=enrichment_stage,
        persistence_stage=persistence_stage,
        ingest_queue=ingest_queue,
        persistence_queue=persistence_queue,
    )

    async def _run() -> None:
        await orchestrator.run()

    with caplog.at_level(logging.ERROR, logger="mcp_plex.loader.orchestrator"):
        with pytest.raises(RuntimeError, match="persistence boom"):
            asyncio.run(_run())

    assert ingestion_stage.cancelled.is_set()
    assert enrichment_stage.cancelled.is_set()
    assert ingest_queue.qsize() == 0
    assert persistence_queue.qsize() == 0
    error_messages = [record.getMessage() for record in caplog.records]
    assert any("Persistence stage" in message for message in error_messages)
2 changes: 1 addition & 1 deletion uv.lock
