diff --git a/docker/pyproject.deps.toml b/docker/pyproject.deps.toml index c968eaf..8347a4c 100644 --- a/docker/pyproject.deps.toml +++ b/docker/pyproject.deps.toml @@ -1,6 +1,6 @@ [project] name = "mcp-plex" -version = "0.26.34" +version = "0.26.35" requires-python = ">=3.11,<3.13" dependencies = [ "fastmcp>=2.11.2", diff --git a/mcp_plex/server.py b/mcp_plex/server.py index 89e32fe..69513d2 100644 --- a/mcp_plex/server.py +++ b/mcp_plex/server.py @@ -14,7 +14,7 @@ from fastmcp.prompts import Message from fastmcp.server import FastMCP from fastmcp.server.context import Context as FastMCPContext -from pydantic import Field +from pydantic import BaseModel, Field, create_model from qdrant_client import models from qdrant_client.async_qdrant_client import AsyncQdrantClient from starlette.requests import Request @@ -99,6 +99,36 @@ def reranker(self) -> CrossEncoder | None: server = PlexServer(settings=settings) +def _request_model(name: str, fn: Callable[..., Any]) -> type[BaseModel] | None: + """Generate a Pydantic model representing the callable's parameters.""" + + signature = inspect.signature(fn) + if not signature.parameters: + return None + + fields: dict[str, tuple[Any, Any]] = {} + for param_name, parameter in signature.parameters.items(): + annotation = ( + parameter.annotation + if parameter.annotation is not inspect._empty + else Any + ) + default = ( + parameter.default + if parameter.default is not inspect._empty + else ... + ) + fields[param_name] = (annotation, default) + + if not fields: + return None + + model_name = "".join(part.capitalize() for part in name.replace("-", "_").split("_")) + model_name = f"{model_name or 'Request'}Request" + request_model = create_model(model_name, **fields) # type: ignore[arg-type] + return request_model + + async def _find_records(identifier: str, limit: int = 5) -> list[models.Record]: """Locate records matching an identifier or title.""" # First, try direct ID lookup @@ -522,15 +552,50 @@ async def rest_docs(request: Request) -> Response: def _build_openapi_schema() -> dict[str, Any]: app = FastAPI() for name, tool in server._tool_manager._tools.items(): - app.post(f"/rest/{name}")(tool.fn) + request_model = _request_model(name, tool.fn) + + if request_model is None: + app.post(f"/rest/{name}")(tool.fn) + continue + + async def _tool_stub(payload: request_model) -> None: # type: ignore[name-defined] + pass + + _tool_stub.__name__ = f"tool_{name.replace('-', '_')}" + _tool_stub.__doc__ = tool.fn.__doc__ + _tool_stub.__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter( + "payload", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=request_model, + ) + ], + return_annotation=Any, + ) + + app.post(f"/rest/{name}")(_tool_stub) for name, prompt in server._prompt_manager._prompts.items(): async def _p_stub(**kwargs): # noqa: ARG001 pass _p_stub.__name__ = f"prompt_{name.replace('-', '_')}" _p_stub.__doc__ = prompt.fn.__doc__ - _p_stub.__signature__ = inspect.signature(prompt.fn).replace( - return_annotation=Any - ) + request_model = _request_model(name, prompt.fn) + if request_model is None: + _p_stub.__signature__ = inspect.signature(prompt.fn).replace( + return_annotation=Any + ) + else: + _p_stub.__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter( + "payload", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=request_model, + ) + ], + return_annotation=Any, + ) app.post(f"/rest/prompt/{name}")(_p_stub) for uri, resource in server._resource_manager._templates.items(): path = uri.replace("resource://", "") diff --git a/pyproject.toml b/pyproject.toml index db7b967..74b6053 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mcp-plex" -version = "0.26.34" +version = "0.26.35" description = "Plex-Oriented Model Context Protocol Server" requires-python = ">=3.11,<3.13" diff --git a/tests/test_server.py b/tests/test_server.py index 4acc7c8..580b104 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -195,10 +195,32 @@ def test_rest_endpoints(monkeypatch): assert resp.json()["rating_key"] == "49915" spec = client.get("/openapi.json").json() + def _resolve(schema: dict): + if "$ref" in schema: + ref = schema["$ref"].split("/")[-1] + return spec["components"]["schemas"][ref] + return schema + get_media = spec["paths"]["/rest/get-media"]["post"] assert get_media["description"].startswith("Retrieve media items") - params = {p["name"]: p for p in get_media["parameters"]} - assert params["identifier"]["schema"]["description"].startswith("Rating key") + assert "parameters" not in get_media or not get_media["parameters"] + get_media_schema = get_media["requestBody"]["content"]["application/json"][ + "schema" + ] + get_media_schema = _resolve(get_media_schema) + assert ( + get_media_schema["properties"]["identifier"]["description"].startswith( + "Rating key" + ) + ) + + search_media = spec["paths"]["/rest/search-media"]["post"] + assert "parameters" not in search_media or not search_media["parameters"] + search_schema = search_media["requestBody"]["content"][ + "application/json" + ]["schema"] + search_schema = _resolve(search_schema) + assert "query" in search_schema["required"] assert "/rest/prompt/media-info" in spec["paths"] assert "/rest/resource/media-ids/{identifier}" in spec["paths"] diff --git a/uv.lock b/uv.lock index c59b28a..dd5d266 100644 --- a/uv.lock +++ b/uv.lock @@ -690,7 +690,7 @@ wheels = [ [[package]] name = "mcp-plex" -version = "0.26.34" +version = "0.26.35" source = { editable = "." } dependencies = [ { name = "fastapi" },