diff --git a/mcp_plex/config.py b/mcp_plex/config.py new file mode 100644 index 0000000..b1b50ce --- /dev/null +++ b/mcp_plex/config.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + """Application configuration settings.""" + + qdrant_url: str | None = Field(default=None, env="QDRANT_URL") + qdrant_api_key: str | None = Field(default=None, env="QDRANT_API_KEY") + qdrant_host: str | None = Field(default=None, env="QDRANT_HOST") + qdrant_port: int = Field(default=6333, env="QDRANT_PORT") + qdrant_grpc_port: int = Field(default=6334, env="QDRANT_GRPC_PORT") + qdrant_prefer_grpc: bool = Field(default=False, env="QDRANT_PREFER_GRPC") + qdrant_https: bool | None = Field(default=None, env="QDRANT_HTTPS") + dense_model: str = Field( + default="BAAI/bge-small-en-v1.5", env="DENSE_MODEL" + ) + sparse_model: str = Field( + default="Qdrant/bm42-all-minilm-l6-v2-attentions", env="SPARSE_MODEL" + ) + cache_size: int = Field(default=128, env="CACHE_SIZE") + use_reranker: bool = Field(default=True, env="USE_RERANKER") + + model_config = SettingsConfigDict(case_sensitive=False) diff --git a/mcp_plex/server.py b/mcp_plex/server.py index 82a771d..d3426db 100644 --- a/mcp_plex/server.py +++ b/mcp_plex/server.py @@ -5,7 +5,6 @@ import asyncio import inspect import json -import os from typing import Annotated, Any, Callable from fastapi import FastAPI @@ -21,56 +20,60 @@ from starlette.responses import JSONResponse, PlainTextResponse, Response from .cache import MediaCache +from .config import Settings try: from sentence_transformers import CrossEncoder except Exception: CrossEncoder = None -# Environment configuration for Qdrant -_QDRANT_URL = os.getenv("QDRANT_URL") -_QDRANT_API_KEY = os.getenv("QDRANT_API_KEY") -_QDRANT_HOST = os.getenv("QDRANT_HOST") -_QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333")) -_QDRANT_GRPC_PORT = int(os.getenv("QDRANT_GRPC_PORT", "6334")) -_QDRANT_PREFER_GRPC = os.getenv("QDRANT_PREFER_GRPC", "0") == "1" -_https_env = os.getenv("QDRANT_HTTPS") -_QDRANT_HTTPS = None if _https_env is None else _https_env == "1" -# Embedding model configuration -_DENSE_MODEL_NAME = os.getenv("DENSE_MODEL", "BAAI/bge-small-en-v1.5") -_SPARSE_MODEL_NAME = os.getenv( - "SPARSE_MODEL", "Qdrant/bm42-all-minilm-l6-v2-attentions" -) - -if _QDRANT_URL is None and _QDRANT_HOST is None: - _QDRANT_URL = ":memory:" - - -_CACHE_SIZE = 128 +settings = Settings() class PlexServer(FastMCP): """FastMCP server with an attached Qdrant client.""" - def __init__(self) -> None: # noqa: D401 - short description inherited - super().__init__() - self.qdrant_client = AsyncQdrantClient( - location=_QDRANT_URL, - api_key=_QDRANT_API_KEY, - host=_QDRANT_HOST, - port=_QDRANT_PORT, - grpc_port=_QDRANT_GRPC_PORT, - prefer_grpc=_QDRANT_PREFER_GRPC, - https=_QDRANT_HTTPS, + def __init__( + self, + *, + settings: Settings | None = None, + qdrant_client: AsyncQdrantClient | None = None, + ) -> None: # noqa: D401 - short description inherited + self._settings = settings or Settings() + location = self.settings.qdrant_url + host = self.settings.qdrant_host + if location is None and host is None: + location = ":memory:" + self.qdrant_client = qdrant_client or AsyncQdrantClient( + location=location, + api_key=self.settings.qdrant_api_key, + host=host, + port=self.settings.qdrant_port, + grpc_port=self.settings.qdrant_grpc_port, + prefer_grpc=self.settings.qdrant_prefer_grpc, + https=self.settings.qdrant_https, ) + + async def _lifespan(app: FastMCP): # noqa: ARG001 + yield + await self.close() + + super().__init__(lifespan=_lifespan) self._reranker: CrossEncoder | None = None self._reranker_loaded = False - self.cache = MediaCache(_CACHE_SIZE) + self.cache = MediaCache(self.settings.cache_size) + + async def close(self) -> None: + await self.qdrant_client.close() + + @property + def settings(self) -> Settings: # type: ignore[override] + return self._settings @property def reranker(self) -> CrossEncoder | None: - if not _USE_RERANKER or CrossEncoder is None: + if not self.settings.use_reranker or CrossEncoder is None: return None if not self._reranker_loaded: try: @@ -83,9 +86,7 @@ def reranker(self) -> CrossEncoder | None: return self._reranker -_USE_RERANKER = os.getenv("USE_RERANKER", "1") == "1" - -server = PlexServer() +server = PlexServer(settings=settings) async def _find_records(identifier: str, limit: int = 5) -> list[models.Record]: @@ -181,8 +182,8 @@ async def search_media( ] = 5, ) -> list[dict[str, Any]]: """Hybrid similarity search across media items using dense and sparse vectors.""" - dense_doc = models.Document(text=query, model=_DENSE_MODEL_NAME) - sparse_doc = models.Document(text=query, model=_SPARSE_MODEL_NAME) + dense_doc = models.Document(text=query, model=server.settings.dense_model) + sparse_doc = models.Document(text=query, model=server.settings.sparse_model) reranker = server.reranker candidate_limit = limit * 3 if reranker is not None else limit prefetch = [ @@ -618,7 +619,6 @@ async def _rest_resource(request: Request, _uri_template=uri, _resource=resource def main(argv: list[str] | None = None) -> None: """CLI entrypoint for running the MCP server.""" - global _DENSE_MODEL_NAME, _SPARSE_MODEL_NAME parser = argparse.ArgumentParser(description="Run the MCP server") parser.add_argument("--bind", help="Host address to bind to") parser.add_argument("--port", type=int, help="Port to listen on") @@ -631,12 +631,12 @@ def main(argv: list[str] | None = None) -> None: parser.add_argument("--mount", help="Mount path for HTTP transports") parser.add_argument( "--dense-model", - default=_DENSE_MODEL_NAME, + default=server.settings.dense_model, help="Dense embedding model name (env: DENSE_MODEL)", ) parser.add_argument( "--sparse-model", - default=_SPARSE_MODEL_NAME, + default=server.settings.sparse_model, help="Sparse embedding model name (env: SPARSE_MODEL)", ) args = parser.parse_args(argv) @@ -653,8 +653,8 @@ def main(argv: list[str] | None = None) -> None: if args.mount: run_kwargs["path"] = args.mount - _DENSE_MODEL_NAME = args.dense_model - _SPARSE_MODEL_NAME = args.sparse_model + server.settings.dense_model = args.dense_model + server.settings.sparse_model = args.sparse_model server.run(transport=args.transport, **run_kwargs) diff --git a/pyproject.toml b/pyproject.toml index ae54d1b..5541645 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mcp-plex" -version = "0.26.11" +version = "0.26.12" description = "Plex-Oriented Model Context Protocol Server" requires-python = ">=3.11,<3.13" diff --git a/tests/test_server.py b/tests/test_server.py index 04360e6..ad1c7c7 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -5,6 +5,7 @@ import json import sys import types +from contextlib import contextmanager from pathlib import Path import builtins @@ -14,6 +15,7 @@ from mcp_plex import loader +@contextmanager def _load_server(monkeypatch): from qdrant_client import async_qdrant_client @@ -36,7 +38,11 @@ def __init__(self, *args, **kwargs): monkeypatch.setattr(async_qdrant_client, "AsyncQdrantClient", SharedClient) sample_dir = Path(__file__).resolve().parents[1] / "sample-data" asyncio.run(loader.run(None, None, None, sample_dir, None, None)) - return importlib.reload(importlib.import_module("mcp_plex.server")) + module = importlib.reload(importlib.import_module("mcp_plex.server")) + try: + yield module + finally: + asyncio.run(module.server.close()) def test_qdrant_env_config(monkeypatch): @@ -47,6 +53,8 @@ def test_qdrant_env_config(monkeypatch): class CaptureClient: def __init__(self, *args, **kwargs): captured.update(kwargs) + async def close(self): + pass monkeypatch.setattr(async_qdrant_client, "AsyncQdrantClient", CaptureClient) monkeypatch.setenv("QDRANT_HOST", "example.com") @@ -62,79 +70,80 @@ def __init__(self, *args, **kwargs): assert captured["prefer_grpc"] is True assert captured["https"] is True assert hasattr(module.server, "qdrant_client") + asyncio.run(module.server.close()) def test_server_tools(monkeypatch): - server = _load_server(monkeypatch) - - movie_id = "49915" - res = asyncio.run(server.get_media.fn(identifier=movie_id)) - assert res and res[0]["plex"]["title"] == "The Gentlemen" - - res = asyncio.run(server.get_media.fn(identifier="tt8367814")) - assert res and res[0]["plex"]["rating_key"] == movie_id - - poster = asyncio.run(server.media_poster.fn(identifier=movie_id)) - assert isinstance(poster, str) and "thumb" in poster - assert server.server.cache.get_poster(movie_id) == poster - - art = asyncio.run(server.media_background.fn(identifier=movie_id)) - assert isinstance(art, str) and "art" in art - assert server.server.cache.get_background(movie_id) == art - - item = json.loads(asyncio.run(server.media_item.fn(identifier=movie_id))) - assert item["plex"]["rating_key"] == movie_id - assert ( - server.server.cache.get_payload(movie_id)["plex"]["rating_key"] == movie_id - ) + with _load_server(monkeypatch) as server: + movie_id = "49915" + res = asyncio.run(server.get_media.fn(identifier=movie_id)) + assert res and res[0]["plex"]["title"] == "The Gentlemen" + + res = asyncio.run(server.get_media.fn(identifier="tt8367814")) + assert res and res[0]["plex"]["rating_key"] == movie_id + + poster = asyncio.run(server.media_poster.fn(identifier=movie_id)) + assert isinstance(poster, str) and "thumb" in poster + assert server.server.cache.get_poster(movie_id) == poster + + art = asyncio.run(server.media_background.fn(identifier=movie_id)) + assert isinstance(art, str) and "art" in art + assert server.server.cache.get_background(movie_id) == art + + item = json.loads(asyncio.run(server.media_item.fn(identifier=movie_id))) + assert item["plex"]["rating_key"] == movie_id + assert ( + server.server.cache.get_payload(movie_id)["plex"]["rating_key"] + == movie_id + ) - ids = json.loads(asyncio.run(server.media_ids.fn(identifier=movie_id))) - assert ids["imdb"] == "tt8367814" + ids = json.loads(asyncio.run(server.media_ids.fn(identifier=movie_id))) + assert ids["imdb"] == "tt8367814" - res = asyncio.run(server.search_media.fn(query="Matthew McConaughey crime movie", limit=1)) - assert res and res[0]["plex"]["title"] == "The Gentlemen" + res = asyncio.run( + server.search_media.fn(query="Matthew McConaughey crime movie", limit=1) + ) + assert res and res[0]["plex"]["title"] == "The Gentlemen" - rec = asyncio.run(server.recommend_media.fn(identifier=movie_id, limit=1)) - assert rec and rec[0]["plex"]["rating_key"] == "61960" + rec = asyncio.run(server.recommend_media.fn(identifier=movie_id, limit=1)) + assert rec and rec[0]["plex"]["rating_key"] == "61960" - assert asyncio.run(server.recommend_media.fn(identifier="0", limit=1)) == [] + assert asyncio.run(server.recommend_media.fn(identifier="0", limit=1)) == [] - with pytest.raises(ValueError): - asyncio.run(server.media_item.fn(identifier="0")) - with pytest.raises(ValueError): - asyncio.run(server.media_ids.fn(identifier="0")) - with pytest.raises(ValueError): - asyncio.run(server.media_poster.fn(identifier="0")) - with pytest.raises(ValueError): - asyncio.run(server.media_background.fn(identifier="0")) + with pytest.raises(ValueError): + asyncio.run(server.media_item.fn(identifier="0")) + with pytest.raises(ValueError): + asyncio.run(server.media_ids.fn(identifier="0")) + with pytest.raises(ValueError): + asyncio.run(server.media_poster.fn(identifier="0")) + with pytest.raises(ValueError): + asyncio.run(server.media_background.fn(identifier="0")) def test_new_media_tools(monkeypatch): - server = _load_server(monkeypatch) - - movies = asyncio.run(server.new_movies.fn(limit=1)) - assert movies and movies[0]["plex"]["type"] == "movie" - assert movies[0]["plex"]["added_at"] is not None + with _load_server(monkeypatch) as server: + movies = asyncio.run(server.new_movies.fn(limit=1)) + assert movies and movies[0]["plex"]["type"] == "movie" + assert movies[0]["plex"]["added_at"] is not None - shows = asyncio.run(server.new_shows.fn(limit=1)) - assert shows and shows[0]["plex"]["type"] == "episode" - assert shows[0]["plex"]["added_at"] is not None + shows = asyncio.run(server.new_shows.fn(limit=1)) + assert shows and shows[0]["plex"]["type"] == "episode" + assert shows[0]["plex"]["added_at"] is not None def test_actor_movies(monkeypatch): - server = _load_server(monkeypatch) - - movies = asyncio.run( - server.actor_movies.fn(actor="Matthew McConaughey", limit=1) - ) - assert movies and movies[0]["plex"]["title"] == "The Gentlemen" + with _load_server(monkeypatch) as server: + movies = asyncio.run( + server.actor_movies.fn(actor="Matthew McConaughey", limit=1) + ) + assert movies and movies[0]["plex"]["title"] == "The Gentlemen" - none = asyncio.run( - server.actor_movies.fn( - actor="Matthew McConaughey", year_from=1990, year_to=1999 + none = asyncio.run( + server.actor_movies.fn( + actor="Matthew McConaughey", year_from=1990, year_to=1999 + ) ) - ) - assert none == [] + assert none == [] def test_reranker_import_failure(monkeypatch): @@ -149,6 +158,7 @@ def fake_import(name, *args, **kwargs): monkeypatch.setattr(builtins, "__import__", fake_import) module = importlib.reload(importlib.import_module("mcp_plex.server")) assert module.server.reranker is None + asyncio.run(module.server.close()) def test_reranker_init_failure(monkeypatch): @@ -163,33 +173,34 @@ def __init__(self, *args, **kwargs): monkeypatch.setitem(sys.modules, "sentence_transformers", st_module) module = importlib.reload(importlib.import_module("mcp_plex.server")) assert module.server.reranker is None + asyncio.run(module.server.close()) def test_rest_endpoints(monkeypatch): - module = _load_server(monkeypatch) - client = TestClient(module.server.http_app()) - - resp = client.post("/rest/get-media", json={"identifier": "49915"}) - assert resp.status_code == 200 - assert resp.json()[0]["plex"]["rating_key"] == "49915" - - resp = client.post("/rest/prompt/media-info", json={"identifier": "49915"}) - assert resp.status_code == 200 - msg = resp.json()[0] - assert msg["role"] == "user" - assert "The Gentlemen" in msg["content"]["text"] - - resp = client.get("/rest/resource/media-ids/49915") - assert resp.status_code == 200 - assert resp.json()["rating_key"] == "49915" - - spec = client.get("/openapi.json").json() - get_media = spec["paths"]["/rest/get-media"]["post"] - assert get_media["description"].startswith("Retrieve media items") - params = {p["name"]: p for p in get_media["parameters"]} - assert params["identifier"]["schema"]["description"].startswith("Rating key") - assert "/rest/prompt/media-info" in spec["paths"] - assert "/rest/resource/media-ids/{identifier}" in spec["paths"] - - resp = client.get("/rest") - assert resp.status_code == 200 + with _load_server(monkeypatch) as module: + client = TestClient(module.server.http_app()) + + resp = client.post("/rest/get-media", json={"identifier": "49915"}) + assert resp.status_code == 200 + assert resp.json()[0]["plex"]["rating_key"] == "49915" + + resp = client.post("/rest/prompt/media-info", json={"identifier": "49915"}) + assert resp.status_code == 200 + msg = resp.json()[0] + assert msg["role"] == "user" + assert "The Gentlemen" in msg["content"]["text"] + + resp = client.get("/rest/resource/media-ids/49915") + assert resp.status_code == 200 + assert resp.json()["rating_key"] == "49915" + + spec = client.get("/openapi.json").json() + get_media = spec["paths"]["/rest/get-media"]["post"] + assert get_media["description"].startswith("Retrieve media items") + params = {p["name"]: p for p in get_media["parameters"]} + assert params["identifier"]["schema"]["description"].startswith("Rating key") + assert "/rest/prompt/media-info" in spec["paths"] + assert "/rest/resource/media-ids/{identifier}" in spec["paths"] + + resp = client.get("/rest") + assert resp.status_code == 200 diff --git a/tests/test_server_cli.py b/tests/test_server_cli.py index 0b52068..f8862b5 100644 --- a/tests/test_server_cli.py +++ b/tests/test_server_cli.py @@ -1,10 +1,18 @@ from unittest.mock import patch +import asyncio +import importlib import pytest from mcp_plex import server +@pytest.fixture(scope="module", autouse=True) +def close_server_module(): + yield + asyncio.run(server.server.close()) + + def test_main_stdio_runs(): with patch.object(server.server, "run") as mock_run: server.main([]) @@ -37,19 +45,19 @@ def test_main_model_overrides(): "--sparse-model", "bar", ]) - assert server._DENSE_MODEL_NAME == "foo" - assert server._SPARSE_MODEL_NAME == "bar" + assert server.settings.dense_model == "foo" + assert server.settings.sparse_model == "bar" mock_run.assert_called_once_with(transport="stdio") def test_env_model_overrides(monkeypatch): monkeypatch.setenv("DENSE_MODEL", "foo") monkeypatch.setenv("SPARSE_MODEL", "bar") - import importlib - + asyncio.run(server.server.close()) importlib.reload(server) - assert server._DENSE_MODEL_NAME == "foo" - assert server._SPARSE_MODEL_NAME == "bar" + assert server.settings.dense_model == "foo" + assert server.settings.sparse_model == "bar" # reload to reset globals + asyncio.run(server.server.close()) importlib.reload(server) diff --git a/uv.lock b/uv.lock index b7e6863..3898f30 100644 --- a/uv.lock +++ b/uv.lock @@ -690,7 +690,7 @@ wheels = [ [[package]] name = "mcp-plex" -version = "0.26.11" +version = "0.26.12" source = { editable = "." } dependencies = [ { name = "fastapi" },