diff --git a/mcp_plex/loader.py b/mcp_plex/loader.py index 5dafee4..445539a 100644 --- a/mcp_plex/loader.py +++ b/mcp_plex/loader.py @@ -460,7 +460,7 @@ def main( ) if not continuous: break - asyncio.sleep(delay) + asyncio.run(asyncio.sleep(delay)) if __name__ == "__main__": diff --git a/tests/test_load_from_plex.py b/tests/test_load_from_plex.py new file mode 100644 index 0000000..136073d --- /dev/null +++ b/tests/test_load_from_plex.py @@ -0,0 +1,78 @@ +import asyncio +import types +import httpx + +from mcp_plex import loader +from mcp_plex.types import TMDBShow + + +def test_load_from_plex(monkeypatch): + movie = types.SimpleNamespace( + ratingKey="1", + guid="g1", + type="movie", + title="Movie", + guids=[ + types.SimpleNamespace(id="imdb://ttm"), + types.SimpleNamespace(id="tmdb://1"), + ], + ) + + ep1 = types.SimpleNamespace( + ratingKey="2", + guid="g2", + type="episode", + title="Ep1", + guids=[ + types.SimpleNamespace(id="imdb://tt1"), + types.SimpleNamespace(id="tmdb://2"), + ], + ) + ep2 = types.SimpleNamespace( + ratingKey="3", + guid="g3", + type="episode", + title="Ep2", + guids=[types.SimpleNamespace(id="imdb://tt2")], + ) + + show = types.SimpleNamespace( + guids=[types.SimpleNamespace(id="tmdb://3")], + episodes=lambda: [ep1, ep2], + ) + + movie_section = types.SimpleNamespace(all=lambda: [movie]) + show_section = types.SimpleNamespace(all=lambda: [show]) + library = types.SimpleNamespace( + section=lambda name: movie_section if name == "Movies" else show_section + ) + server = types.SimpleNamespace(library=library) + + async def handler(request): + url = str(request.url) + if "imdbapi" in url: + return httpx.Response( + 200, json={"id": "tt", "type": "movie", "primaryTitle": "IMDb"} + ) + if "/movie/1" in url: + return httpx.Response(200, json={"id": 1, "title": "TMDB Movie"}) + if "/tv/3" in url: + return httpx.Response(200, json={"id": 3, "name": "TMDB Show"}) + if "/episode/2" in url: + return httpx.Response(200, json={"id": 2, "name": "TMDB Ep"}) + return httpx.Response(404) + + transport = httpx.MockTransport(handler) + orig_client = httpx.AsyncClient + monkeypatch.setattr( + loader.httpx, + "AsyncClient", + lambda *args, **kwargs: orig_client(transport=transport), + ) + + items = asyncio.run(loader._load_from_plex(server, "key")) + assert len(items) == 3 + assert items[0].tmdb and items[0].tmdb.id == 1 + assert items[1].tmdb and items[1].tmdb.id == 2 + assert isinstance(items[2].tmdb, TMDBShow) + assert items[2].tmdb.id == 3 diff --git a/tests/test_loader_cli.py b/tests/test_loader_cli.py new file mode 100644 index 0000000..b659c6c --- /dev/null +++ b/tests/test_loader_cli.py @@ -0,0 +1,46 @@ +import asyncio +from click.testing import CliRunner +import pytest + +from mcp_plex import loader + + +def test_cli_continuous_respects_delay(monkeypatch): + actions: list = [] + run_calls = 0 + + async def fake_run(*args, **kwargs): + nonlocal run_calls + run_calls += 1 + actions.append("run") + if run_calls >= 2: + raise RuntimeError("stop") + + async def fake_sleep(seconds): + actions.append(("sleep", seconds)) + + monkeypatch.setattr(loader, "run", fake_run) + monkeypatch.setattr(asyncio, "sleep", fake_sleep) + + runner = CliRunner() + with pytest.raises(RuntimeError, match="stop"): + runner.invoke(loader.main, ["--continuous", "--delay", "7.5"], catch_exceptions=False) + + assert actions == ["run", ("sleep", 7.5), "run"] + + +def test_cli_invalid_delay_value(): + runner = CliRunner() + result = runner.invoke(loader.main, ["--delay", "not-a-number"]) + assert result.exit_code != 0 + assert "Invalid value for '--delay'" in result.output + + +def test_run_requires_credentials(monkeypatch): + monkeypatch.setattr(loader, "PlexServer", object) + + async def invoke(): + await loader.run(None, None, "key", None, None, None) + + with pytest.raises(RuntimeError, match="PLEX_URL and PLEX_TOKEN must be provided"): + asyncio.run(invoke()) diff --git a/tests/test_loader_integration.py b/tests/test_loader_integration.py index 79377f6..d75b34c 100644 --- a/tests/test_loader_integration.py +++ b/tests/test_loader_integration.py @@ -64,6 +64,27 @@ async def upsert(self, collection_name: str, points): self.upserted.extend(points) +class TrackingQdrantClient(DummyQdrantClient): + """Qdrant client that starts with a mismatched collection size.""" + + def __init__(self, url: str, api_key: str | None = None): + super().__init__(url, api_key) + # Pre-create a collection with the wrong vector size to force recreation + wrong_params = SimpleNamespace( + vectors={ + "dense": models.VectorParams(size=99, distance=models.Distance.COSINE) + } + ) + self.collections["media-items"] = SimpleNamespace( + config=SimpleNamespace(params=wrong_params) + ) + self.deleted = False + + async def delete_collection(self, name: str): + self.deleted = True + await super().delete_collection(name) + + async def _run_loader(sample_dir: Path): await loader.run(None, None, None, sample_dir, None, None) @@ -79,3 +100,19 @@ def test_run_writes_points(monkeypatch): assert len(client.upserted) == 2 payloads = [p.payload for p in client.upserted] assert all("title" in p and "type" in p for p in payloads) + + +def test_run_recreates_mismatched_collection(monkeypatch): + monkeypatch.setattr(loader, "TextEmbedding", DummyTextEmbedding) + monkeypatch.setattr(loader, "SparseTextEmbedding", DummySparseEmbedding) + monkeypatch.setattr(loader, "AsyncQdrantClient", TrackingQdrantClient) + sample_dir = Path(__file__).resolve().parents[1] / "sample-data" + asyncio.run(_run_loader(sample_dir)) + client = TrackingQdrantClient.instance + assert client is not None + # The pre-created collection should have been deleted and recreated + assert client.deleted is True + assert ( + client.collections["media-items"].config.params.vectors["dense"].size + == 3 + )