diff --git a/docker/pyproject.deps.toml b/docker/pyproject.deps.toml index 04f2ca3..a4c4695 100644 --- a/docker/pyproject.deps.toml +++ b/docker/pyproject.deps.toml @@ -1,6 +1,6 @@ [project] name = "mcp-plex" -version = "1.0.22" +version = "1.0.23" requires-python = ">=3.11,<3.13" dependencies = [ "fastmcp>=2.11.2", diff --git a/mcp_plex/server/__init__.py b/mcp_plex/server/__init__.py index af35087..6edef39 100644 --- a/mcp_plex/server/__init__.py +++ b/mcp_plex/server/__init__.py @@ -9,6 +9,7 @@ import logging import os import uuid +from dataclasses import dataclass from typing import Annotated, Any, Callable, Mapping, Sequence, cast from typing import NotRequired, TypedDict @@ -248,25 +249,33 @@ def clear_plex_identity_cache(self) -> None: server = PlexServer(settings=settings) -def _request_model(name: str, fn: Callable[..., Any]) -> type[BaseModel] | None: +def _request_model(name: str, fn: Callable[..., object]) -> type[BaseModel] | None: """Generate a Pydantic model representing the callable's parameters.""" signature = inspect.signature(fn) if not signature.parameters: return None - fields: dict[str, tuple[Any, Any]] = {} + fields: dict[str, tuple[object, object]] = {} for param_name, parameter in signature.parameters.items(): - annotation = ( - parameter.annotation - if parameter.annotation is not inspect._empty - else Any - ) - default = ( - parameter.default - if parameter.default is not inspect._empty - else ... - ) + if parameter.kind in { + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + }: + continue + + annotation: object + if parameter.annotation is inspect.Signature.empty: + annotation = object + else: + annotation = parameter.annotation + + default: object + if parameter.default is inspect.Signature.empty: + default = ... + else: + default = parameter.default + fields[param_name] = (annotation, default) if not fields: @@ -332,7 +341,7 @@ def _flatten_payload(payload: Mapping[str, JSONValue] | None) -> AggregatedMedia return cast(AggregatedMediaItem, data) -def _normalize_identifier(value: JSONValue) -> str | None: +def _normalize_identifier(value: str | int | float | None) -> str | None: """Convert mixed identifier formats into a normalized string.""" if value is None: @@ -376,19 +385,17 @@ async def _get_media_data(identifier: str) -> AggregatedMediaItem: cache_keys.add(lookup_key) plex_data = _extract_plex_metadata(data) - rating_key = _normalize_identifier(cast(JSONValue, plex_data.get("rating_key"))) + rating_key = _normalize_identifier(plex_data.get("rating_key")) if rating_key: cache_keys.add(rating_key) - guid = _normalize_identifier(cast(JSONValue, plex_data.get("guid"))) + guid = _normalize_identifier(plex_data.get("guid")) if guid: cache_keys.add(guid) for source_key in ("imdb", "tmdb", "tvdb"): source_value = data.get(source_key) if isinstance(source_value, dict): - source_id = _normalize_identifier( - cast(JSONValue, source_value.get("id")) - ) + source_id = _normalize_identifier(source_value.get("id")) if source_id: cache_keys.add(source_id) @@ -654,7 +661,7 @@ async def play_media( media = await _get_media_data(identifier) plex_info = _extract_plex_metadata(media) rating_key_value = plex_info.get("rating_key") - rating_key_normalized = _normalize_identifier(cast(JSONValue, rating_key_value)) + rating_key_normalized = _normalize_identifier(rating_key_value) if not rating_key_normalized: raise ValueError("Media item is missing a Plex rating key") @@ -736,9 +743,7 @@ async def search_media( async def _prefetch(hit: models.ScoredPoint) -> None: data = _flatten_payload(cast(Mapping[str, JSONValue] | None, hit.payload)) plex_info = _extract_plex_metadata(data) - rating_key = _normalize_identifier( - cast(JSONValue, plex_info.get("rating_key")) - ) + rating_key = _normalize_identifier(plex_info.get("rating_key")) if rating_key: server.cache.set_payload(rating_key, cast(dict[str, JSONValue], data)) thumb = plex_info.get("thumb") @@ -1361,7 +1366,7 @@ async def media_poster( if not thumb: raise ValueError("Poster not available") thumb_str = str(thumb) - rating_key = _normalize_identifier(cast(JSONValue, plex_info.get("rating_key"))) + rating_key = _normalize_identifier(plex_info.get("rating_key")) if rating_key: server.cache.set_poster(rating_key, thumb_str) return thumb_str @@ -1387,7 +1392,7 @@ async def media_background( if not art: raise ValueError("Background not available") art_str = str(art) - rating_key = _normalize_identifier(cast(JSONValue, plex_info.get("rating_key"))) + rating_key = _normalize_identifier(plex_info.get("rating_key")) if rating_key: server.cache.set_background(rating_key, art_str) return art_str @@ -1417,7 +1422,7 @@ async def rest_docs(request: Request) -> Response: return get_swagger_ui_html(openapi_url="/openapi.json", title="MCP REST API") -def _build_openapi_schema() -> dict[str, Any]: +def _build_openapi_schema() -> dict[str, object]: app = FastAPI() for name, tool in server._tool_manager._tools.items(): request_model = _request_model(name, tool.fn) @@ -1439,7 +1444,7 @@ async def _tool_stub(payload: request_model) -> None: # type: ignore[name-defin annotation=request_model, ) ], - return_annotation=Any, + return_annotation=inspect.Signature.empty, ) app.post(f"/rest/{name}")(_tool_stub) @@ -1449,9 +1454,10 @@ async def _p_stub(**kwargs): # noqa: ARG001 _p_stub.__name__ = f"prompt_{name.replace('-', '_')}" _p_stub.__doc__ = prompt.fn.__doc__ request_model = _request_model(name, prompt.fn) + prompt_signature = inspect.signature(prompt.fn) if request_model is None: - _p_stub.__signature__ = inspect.signature(prompt.fn).replace( - return_annotation=Any + _p_stub.__signature__ = prompt_signature.replace( + return_annotation=inspect.Signature.empty ) else: _p_stub.__signature__ = inspect.Signature( @@ -1462,7 +1468,7 @@ async def _p_stub(**kwargs): # noqa: ARG001 annotation=request_model, ) ], - return_annotation=Any, + return_annotation=inspect.Signature.empty, ) app.post(f"/rest/prompt/{name}")(_p_stub) for uri, resource in server._resource_manager._templates.items(): @@ -1472,7 +1478,7 @@ async def _r_stub(**kwargs): # noqa: ARG001 _r_stub.__name__ = f"resource_{path.replace('/', '_').replace('{', '').replace('}', '')}" _r_stub.__doc__ = resource.fn.__doc__ _r_stub.__signature__ = inspect.signature(resource.fn).replace( - return_annotation=Any + return_annotation=inspect.Signature.empty ) app.get(f"/rest/resource/{path}")(_r_stub) return get_openapi(title="MCP REST API", version="1.0.0", routes=app.routes) @@ -1492,7 +1498,9 @@ def _register_rest_endpoints() -> None: def _register(path: str, method: str, handler: Callable, fn: Callable, name: str) -> None: handler.__name__ = name handler.__doc__ = fn.__doc__ - handler.__signature__ = inspect.signature(fn).replace(return_annotation=Any) + handler.__signature__ = inspect.signature(fn).replace( + return_annotation=inspect.Signature.empty + ) server.custom_route(path, methods=[method])(handler) for name, tool in server._tool_manager._tools.items(): @@ -1560,6 +1568,27 @@ async def _rest_resource(request: Request, _uri_template=uri, _resource=resource _register_rest_endpoints() +@dataclass +class RunConfig: + """Runtime configuration for FastMCP transport servers.""" + + host: str | None = None + port: int | None = None + path: str | None = None + + def to_kwargs(self) -> dict[str, object]: + """Return keyword arguments compatible with ``FastMCP.run``.""" + + kwargs: dict[str, object] = {} + if self.host is not None: + kwargs["host"] = self.host + if self.port is not None: + kwargs["port"] = self.port + if self.path: + kwargs["path"] = self.path + return kwargs + + def main(argv: list[str] | None = None) -> None: """CLI entrypoint for running the MCP server.""" parser = argparse.ArgumentParser(description="Run the MCP server") @@ -1616,16 +1645,19 @@ def main(argv: list[str] | None = None) -> None: if transport == "stdio" and mount: parser.error("--mount or MCP_MOUNT is not allowed when transport is stdio") - run_kwargs: dict[str, Any] = {} + run_config = RunConfig() if transport != "stdio": - run_kwargs.update({"host": host, "port": port}) + if host is not None: + run_config.host = host + if port is not None: + run_config.port = port if mount: - run_kwargs["path"] = mount + run_config.path = mount server.settings.dense_model = args.dense_model server.settings.sparse_model = args.sparse_model - server.run(transport=transport, **run_kwargs) + server.run(transport=transport, **run_config.to_kwargs()) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index ab4f7b9..d4cb481 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mcp-plex" -version = "1.0.22" +version = "1.0.23" description = "Plex-Oriented Model Context Protocol Server" requires-python = ">=3.11,<3.13" diff --git a/tests/test_server.py b/tests/test_server.py index b326953..27312ce 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -18,6 +18,7 @@ from mcp_plex import loader from mcp_plex import server as server_module +from pydantic import ValidationError @contextmanager @@ -517,6 +518,47 @@ async def _noop() -> None: assert module._request_model("noop", _noop) is None +def test_request_model_missing_annotation_uses_object(): + module = importlib.import_module("mcp_plex.server") + + async def _unannotated(foo): # type: ignore[no-untyped-def] + return foo + + request_model = module._request_model("unannotated", _unannotated) + assert request_model is not None + field = request_model.model_fields["foo"] + assert field.annotation is object + with pytest.raises(ValidationError): + request_model() + instance = request_model(foo="value") + assert instance.foo == "value" + + +def test_normalize_identifier_scalar_inputs(): + module = importlib.import_module("mcp_plex.server") + + assert module._normalize_identifier(" value ") == "value" + assert module._normalize_identifier(123) == "123" + assert module._normalize_identifier(0.0) == "0.0" + assert module._normalize_identifier("") is None + assert module._normalize_identifier(None) is None + + +def test_run_config_to_kwargs(): + module = importlib.import_module("mcp_plex.server") + + config = module.RunConfig() + assert config.to_kwargs() == {} + + config.host = "127.0.0.1" + config.port = 8080 + assert config.to_kwargs() == {"host": "127.0.0.1", "port": 8080} + + config.path = "/plex" + kwargs = config.to_kwargs() + assert kwargs["path"] == "/plex" + + def test_find_records_handles_retrieve_error(monkeypatch): with _load_server(monkeypatch) as module: async def fail_retrieve(*args, **kwargs): diff --git a/uv.lock b/uv.lock index 10127a4..2f7701f 100644 --- a/uv.lock +++ b/uv.lock @@ -730,7 +730,7 @@ wheels = [ [[package]] name = "mcp-plex" -version = "1.0.22" +version = "1.0.23" source = { editable = "." } dependencies = [ { name = "fastapi" },