mcp_plex/loader/__init__.py: 155 changes (79 additions, 76 deletions)
@@ -364,86 +364,89 @@ async def run(
         https=qdrant_https,
         prefer_grpc=qdrant_prefer_grpc,
     )
-    collection_name = "media-items"
-    await _ensure_collection(
-        client,
-        collection_name,
-        dense_size=dense_size,
-        dense_distance=dense_distance,
-    )
+    try:
+        collection_name = "media-items"
+        await _ensure_collection(
+            client,
+            collection_name,
+            dense_size=dense_size,
+            dense_distance=dense_distance,
+        )
 
-    items: list[AggregatedItem]
-    if sample_dir is not None:
-        logger.info("Loading sample data from %s", sample_dir)
-        sample_items = samples._load_from_sample(sample_dir)
-        orchestrator, items, qdrant_retry_queue = _build_loader_orchestrator(
-            client=client,
-            collection_name=collection_name,
-            dense_model_name=dense_model_name,
-            sparse_model_name=sparse_model_name,
-            tmdb_api_key=None,
-            sample_items=sample_items,
-            plex_server=None,
-            plex_chunk_size=plex_chunk_size,
-            enrichment_batch_size=enrichment_batch_size,
-            enrichment_workers=enrichment_workers,
-            upsert_buffer_size=upsert_buffer_size,
-            max_concurrent_upserts=max_concurrent_upserts,
-            imdb_config=IMDbRuntimeConfig(
-                cache=imdb_config.cache,
-                max_retries=imdb_config.max_retries,
-                backoff=imdb_config.backoff,
-                retry_queue=IMDbRetryQueue(),
-                requests_per_window=imdb_config.requests_per_window,
-                window_seconds=imdb_config.window_seconds,
-            ),
-            qdrant_config=qdrant_config,
-        )
-        logger.info("Starting staged loader (sample mode)")
-        await orchestrator.run()
-    else:
-        if PlexServer is None:
-            raise RuntimeError("plexapi is required for live loading")
-        if not plex_url or not plex_token:
-            raise RuntimeError("PLEX_URL and PLEX_TOKEN must be provided")
-        if not tmdb_api_key:
-            raise RuntimeError("TMDB_API_KEY must be provided")
-        logger.info("Loading data from Plex server %s", plex_url)
-        server = PlexServer(plex_url, plex_token)
-        orchestrator, items, qdrant_retry_queue = _build_loader_orchestrator(
-            client=client,
-            collection_name=collection_name,
-            dense_model_name=dense_model_name,
-            sparse_model_name=sparse_model_name,
-            tmdb_api_key=tmdb_api_key,
-            sample_items=None,
-            plex_server=server,
-            plex_chunk_size=plex_chunk_size,
-            enrichment_batch_size=enrichment_batch_size,
-            enrichment_workers=enrichment_workers,
-            upsert_buffer_size=upsert_buffer_size,
-            max_concurrent_upserts=max_concurrent_upserts,
-            imdb_config=imdb_config,
-            qdrant_config=qdrant_config,
-        )
-        logger.info("Starting staged loader (Plex mode)")
-        await orchestrator.run()
-    logger.info("Loaded %d items", len(items))
-    if not items:
-        logger.info("No points to upsert")
+        items: list[AggregatedItem]
+        if sample_dir is not None:
+            logger.info("Loading sample data from %s", sample_dir)
+            sample_items = samples._load_from_sample(sample_dir)
+            orchestrator, items, qdrant_retry_queue = _build_loader_orchestrator(
+                client=client,
+                collection_name=collection_name,
+                dense_model_name=dense_model_name,
+                sparse_model_name=sparse_model_name,
+                tmdb_api_key=None,
+                sample_items=sample_items,
+                plex_server=None,
+                plex_chunk_size=plex_chunk_size,
+                enrichment_batch_size=enrichment_batch_size,
+                enrichment_workers=enrichment_workers,
+                upsert_buffer_size=upsert_buffer_size,
+                max_concurrent_upserts=max_concurrent_upserts,
+                imdb_config=IMDbRuntimeConfig(
+                    cache=imdb_config.cache,
+                    max_retries=imdb_config.max_retries,
+                    backoff=imdb_config.backoff,
+                    retry_queue=IMDbRetryQueue(),
+                    requests_per_window=imdb_config.requests_per_window,
+                    window_seconds=imdb_config.window_seconds,
+                ),
+                qdrant_config=qdrant_config,
+            )
+            logger.info("Starting staged loader (sample mode)")
+            await orchestrator.run()
+        else:
+            if PlexServer is None:
+                raise RuntimeError("plexapi is required for live loading")
+            if not plex_url or not plex_token:
+                raise RuntimeError("PLEX_URL and PLEX_TOKEN must be provided")
+            if not tmdb_api_key:
+                raise RuntimeError("TMDB_API_KEY must be provided")
+            logger.info("Loading data from Plex server %s", plex_url)
+            server = PlexServer(plex_url, plex_token)
+            orchestrator, items, qdrant_retry_queue = _build_loader_orchestrator(
+                client=client,
+                collection_name=collection_name,
+                dense_model_name=dense_model_name,
+                sparse_model_name=sparse_model_name,
+                tmdb_api_key=tmdb_api_key,
+                sample_items=None,
+                plex_server=server,
+                plex_chunk_size=plex_chunk_size,
+                enrichment_batch_size=enrichment_batch_size,
+                enrichment_workers=enrichment_workers,
+                upsert_buffer_size=upsert_buffer_size,
+                max_concurrent_upserts=max_concurrent_upserts,
+                imdb_config=imdb_config,
+                qdrant_config=qdrant_config,
+            )
+            logger.info("Starting staged loader (Plex mode)")
+            await orchestrator.run()
+        logger.info("Loaded %d items", len(items))
+        if not items:
+            logger.info("No points to upsert")
 
-    await _process_qdrant_retry_queue(
-        client,
-        collection_name,
-        qdrant_retry_queue,
-        config=qdrant_config,
-    )
+        await _process_qdrant_retry_queue(
+            client,
+            collection_name,
+            qdrant_retry_queue,
+            config=qdrant_config,
+        )
 
-    if imdb_queue_path:
-        _persist_imdb_retry_queue(imdb_queue_path, imdb_config.retry_queue)
+        if imdb_queue_path:
+            _persist_imdb_retry_queue(imdb_queue_path, imdb_config.retry_queue)
 
-    json.dump([item.model_dump(mode="json") for item in items], fp=sys.stdout, indent=2)
-    sys.stdout.write("\n")
+        json.dump([item.model_dump(mode="json") for item in items], fp=sys.stdout, indent=2)
+        sys.stdout.write("\n")
+    finally:
+        await client.close()
 
 
 async def load_media(
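The loader change above is a standard try/finally cleanup: every code path, including the early RuntimeError raises and any orchestrator failure, now releases the Qdrant client. A minimal sketch of the pattern, with a hypothetical run_with_cleanup wrapper standing in for the real run() body:

from qdrant_client import AsyncQdrantClient


async def run_with_cleanup(client: AsyncQdrantClient) -> None:
    try:
        # Collection setup, the staged orchestrator, retry-queue
        # draining, and JSON output all live inside the try block.
        ...
    finally:
        # Runs on success and on any exception above, so the client's
        # HTTP/gRPC connections are always released.
        await client.close()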
tests/test_loader_integration.py: 27 changes (22 additions, 5 deletions)
@@ -17,6 +17,7 @@ class CaptureClient(AsyncQdrantClient):
     captured_points: list[models.PointStruct] = []
     upsert_calls: int = 0
     created_indexes: list[tuple[str, Any]] = []
+    close_calls: int = 0
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -44,6 +45,10 @@ async def create_payload_index(
             wait=wait,
         )
 
+    async def close(self) -> None:
+        CaptureClient.close_calls += 1
+        await super().close()
+
 
 async def _run_loader(sample_dir: Path, **kwargs) -> None:
     await loader.run(
@@ -62,6 +67,7 @@ def test_run_writes_points(monkeypatch):
     CaptureClient.captured_points = []
     CaptureClient.upsert_calls = 0
     CaptureClient.created_indexes = []
+    CaptureClient.close_calls = 0
     sample_dir = Path(__file__).resolve().parents[1] / "sample-data"
     asyncio.run(_run_loader(sample_dir))
     client = CaptureClient.instance
@@ -70,9 +76,6 @@
     assert index_map.get("show_title") == models.PayloadSchemaType.KEYWORD
     assert index_map.get("season_number") == models.PayloadSchemaType.INTEGER
     assert index_map.get("episode_number") == models.PayloadSchemaType.INTEGER
-    points, _ = asyncio.run(client.scroll("media-items", limit=10, with_payload=True))
-    assert len(points) == 2
-    assert all("title" in p.payload and "type" in p.payload for p in points)
     captured = CaptureClient.captured_points
     assert len(captured) == 2
     assert all(isinstance(p.vector["dense"], models.Document) for p in captured)
@@ -85,7 +88,7 @@ def test_run_writes_points(monkeypatch):
     texts = [p.vector["dense"].text for p in captured]
     assert any("Directed by" in t for t in texts)
     assert any("Starring" in t for t in texts)
-    movie_point = next(p for p in points if p.payload["type"] == "movie")
+    movie_point = next(p for p in captured if p.payload["type"] == "movie")
     assert "directors" in movie_point.payload and "Guy Ritchie" in movie_point.payload["directors"]
     assert "writers" in movie_point.payload and movie_point.payload["writers"]
     assert "genres" in movie_point.payload and movie_point.payload["genres"]
@@ -94,7 +97,7 @@
     assert movie_point.payload.get("plot")
     assert movie_point.payload.get("tagline")
     assert movie_point.payload.get("reviews")
-    episode_point = next(p for p in points if p.payload["type"] == "episode")
+    episode_point = next(p for p in captured if p.payload["type"] == "episode")
     assert episode_point.payload.get("show_title") == "Alien: Earth"
     assert episode_point.payload.get("season_title") == "Season 1"
     assert episode_point.payload.get("season_number") == 1
@@ -110,6 +113,7 @@ def test_run_processes_imdb_queue(monkeypatch, tmp_path):
     monkeypatch.setattr(loader, "AsyncQdrantClient", CaptureClient)
     CaptureClient.captured_points = []
     CaptureClient.upsert_calls = 0
+    CaptureClient.close_calls = 0
     queue_file = tmp_path / "queue.json"
     queue_file.write_text(json.dumps(["tt0111161"]))
     sample_dir = Path(__file__).resolve().parents[1] / "sample-data"
@@ -138,12 +142,25 @@ def test_run_upserts_in_batches(monkeypatch):
     monkeypatch.setattr(loader, "AsyncQdrantClient", CaptureClient)
     CaptureClient.captured_points = []
     CaptureClient.upsert_calls = 0
+    CaptureClient.close_calls = 0
     sample_dir = Path(__file__).resolve().parents[1] / "sample-data"
     asyncio.run(_run_loader(sample_dir, qdrant_batch_size=1))
     assert CaptureClient.upsert_calls == 2
     assert len(CaptureClient.captured_points) == 2
 
 
+def test_run_closes_client_once(monkeypatch):
+    monkeypatch.setattr(loader, "AsyncQdrantClient", CaptureClient)
+    CaptureClient.close_calls = 0
+    sample_dir = Path(__file__).resolve().parents[1] / "sample-data"
+
+    asyncio.run(_run_loader(sample_dir))
+    assert CaptureClient.close_calls == 1
+
+    asyncio.run(_run_loader(sample_dir))
+    assert CaptureClient.close_calls == 2
+
+
 def test_run_raises_for_unknown_dense_model():
     sample_dir = Path(__file__).resolve().parents[1] / "sample-data"
 
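On the test side, the close tracking works through a counting spy: a subclass of the real client increments a class-level counter in close(), and each test resets the counter before running so repeated runs stay independent. A self-contained sketch of the same idea under hypothetical names (Resource, CountingResource), not tied to the suite above:

import asyncio


class Resource:
    async def close(self) -> None:
        pass  # stands in for AsyncQdrantClient.close


class CountingResource(Resource):
    close_calls: int = 0  # class attribute, shared across instances

    async def close(self) -> None:
        CountingResource.close_calls += 1
        await super().close()


async def use(resource: Resource) -> None:
    try:
        pass  # do some work with the resource
    finally:
        await resource.close()


CountingResource.close_calls = 0  # reset, as each test does
asyncio.run(use(CountingResource()))
assert CountingResource.close_calls == 1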