Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/pyproject.deps.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "mcp-plex"
version = "1.0.22"
version = "1.0.23"
requires-python = ">=3.11,<3.13"
dependencies = [
"fastmcp>=2.11.2",
Expand Down
102 changes: 67 additions & 35 deletions mcp_plex/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
import os
import uuid
from dataclasses import dataclass
from typing import Annotated, Any, Callable, Mapping, Sequence, cast
from typing import NotRequired, TypedDict

Expand Down Expand Up @@ -248,25 +249,33 @@ def clear_plex_identity_cache(self) -> None:
server = PlexServer(settings=settings)


def _request_model(name: str, fn: Callable[..., Any]) -> type[BaseModel] | None:
def _request_model(name: str, fn: Callable[..., object]) -> type[BaseModel] | None:
"""Generate a Pydantic model representing the callable's parameters."""

signature = inspect.signature(fn)
if not signature.parameters:
return None

fields: dict[str, tuple[Any, Any]] = {}
fields: dict[str, tuple[object, object]] = {}
for param_name, parameter in signature.parameters.items():
annotation = (
parameter.annotation
if parameter.annotation is not inspect._empty
else Any
)
default = (
parameter.default
if parameter.default is not inspect._empty
else ...
)
if parameter.kind in {
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
}:
continue

annotation: object
if parameter.annotation is inspect.Signature.empty:
annotation = object
else:
annotation = parameter.annotation

default: object
if parameter.default is inspect.Signature.empty:
default = ...
else:
default = parameter.default

fields[param_name] = (annotation, default)

if not fields:
Expand Down Expand Up @@ -332,7 +341,7 @@ def _flatten_payload(payload: Mapping[str, JSONValue] | None) -> AggregatedMedia
return cast(AggregatedMediaItem, data)


def _normalize_identifier(value: JSONValue) -> str | None:
def _normalize_identifier(value: str | int | float | None) -> str | None:
"""Convert mixed identifier formats into a normalized string."""

if value is None:
Expand Down Expand Up @@ -376,19 +385,17 @@ async def _get_media_data(identifier: str) -> AggregatedMediaItem:
cache_keys.add(lookup_key)

plex_data = _extract_plex_metadata(data)
rating_key = _normalize_identifier(cast(JSONValue, plex_data.get("rating_key")))
rating_key = _normalize_identifier(plex_data.get("rating_key"))
if rating_key:
cache_keys.add(rating_key)
guid = _normalize_identifier(cast(JSONValue, plex_data.get("guid")))
guid = _normalize_identifier(plex_data.get("guid"))
if guid:
cache_keys.add(guid)

for source_key in ("imdb", "tmdb", "tvdb"):
source_value = data.get(source_key)
if isinstance(source_value, dict):
source_id = _normalize_identifier(
cast(JSONValue, source_value.get("id"))
)
source_id = _normalize_identifier(source_value.get("id"))
if source_id:
cache_keys.add(source_id)

Expand Down Expand Up @@ -654,7 +661,7 @@ async def play_media(
media = await _get_media_data(identifier)
plex_info = _extract_plex_metadata(media)
rating_key_value = plex_info.get("rating_key")
rating_key_normalized = _normalize_identifier(cast(JSONValue, rating_key_value))
rating_key_normalized = _normalize_identifier(rating_key_value)
if not rating_key_normalized:
raise ValueError("Media item is missing a Plex rating key")

Expand Down Expand Up @@ -736,9 +743,7 @@ async def search_media(
async def _prefetch(hit: models.ScoredPoint) -> None:
data = _flatten_payload(cast(Mapping[str, JSONValue] | None, hit.payload))
plex_info = _extract_plex_metadata(data)
rating_key = _normalize_identifier(
cast(JSONValue, plex_info.get("rating_key"))
)
rating_key = _normalize_identifier(plex_info.get("rating_key"))
if rating_key:
server.cache.set_payload(rating_key, cast(dict[str, JSONValue], data))
thumb = plex_info.get("thumb")
Expand Down Expand Up @@ -1361,7 +1366,7 @@ async def media_poster(
if not thumb:
raise ValueError("Poster not available")
thumb_str = str(thumb)
rating_key = _normalize_identifier(cast(JSONValue, plex_info.get("rating_key")))
rating_key = _normalize_identifier(plex_info.get("rating_key"))
if rating_key:
server.cache.set_poster(rating_key, thumb_str)
return thumb_str
Expand All @@ -1387,7 +1392,7 @@ async def media_background(
if not art:
raise ValueError("Background not available")
art_str = str(art)
rating_key = _normalize_identifier(cast(JSONValue, plex_info.get("rating_key")))
rating_key = _normalize_identifier(plex_info.get("rating_key"))
if rating_key:
server.cache.set_background(rating_key, art_str)
return art_str
Expand Down Expand Up @@ -1417,7 +1422,7 @@ async def rest_docs(request: Request) -> Response:
return get_swagger_ui_html(openapi_url="/openapi.json", title="MCP REST API")


def _build_openapi_schema() -> dict[str, Any]:
def _build_openapi_schema() -> dict[str, object]:
app = FastAPI()
for name, tool in server._tool_manager._tools.items():
request_model = _request_model(name, tool.fn)
Expand All @@ -1439,7 +1444,7 @@ async def _tool_stub(payload: request_model) -> None: # type: ignore[name-defin
annotation=request_model,
)
],
return_annotation=Any,
return_annotation=inspect.Signature.empty,
)

app.post(f"/rest/{name}")(_tool_stub)
Expand All @@ -1449,9 +1454,10 @@ async def _p_stub(**kwargs): # noqa: ARG001
_p_stub.__name__ = f"prompt_{name.replace('-', '_')}"
_p_stub.__doc__ = prompt.fn.__doc__
request_model = _request_model(name, prompt.fn)
prompt_signature = inspect.signature(prompt.fn)
if request_model is None:
_p_stub.__signature__ = inspect.signature(prompt.fn).replace(
return_annotation=Any
_p_stub.__signature__ = prompt_signature.replace(
return_annotation=inspect.Signature.empty
)
else:
_p_stub.__signature__ = inspect.Signature(
Expand All @@ -1462,7 +1468,7 @@ async def _p_stub(**kwargs): # noqa: ARG001
annotation=request_model,
)
],
return_annotation=Any,
return_annotation=inspect.Signature.empty,
)
app.post(f"/rest/prompt/{name}")(_p_stub)
for uri, resource in server._resource_manager._templates.items():
Expand All @@ -1472,7 +1478,7 @@ async def _r_stub(**kwargs): # noqa: ARG001
_r_stub.__name__ = f"resource_{path.replace('/', '_').replace('{', '').replace('}', '')}"
_r_stub.__doc__ = resource.fn.__doc__
_r_stub.__signature__ = inspect.signature(resource.fn).replace(
return_annotation=Any
return_annotation=inspect.Signature.empty
)
app.get(f"/rest/resource/{path}")(_r_stub)
return get_openapi(title="MCP REST API", version="1.0.0", routes=app.routes)
Expand All @@ -1492,7 +1498,9 @@ def _register_rest_endpoints() -> None:
def _register(path: str, method: str, handler: Callable, fn: Callable, name: str) -> None:
handler.__name__ = name
handler.__doc__ = fn.__doc__
handler.__signature__ = inspect.signature(fn).replace(return_annotation=Any)
handler.__signature__ = inspect.signature(fn).replace(
return_annotation=inspect.Signature.empty
)
server.custom_route(path, methods=[method])(handler)

for name, tool in server._tool_manager._tools.items():
Expand Down Expand Up @@ -1560,6 +1568,27 @@ async def _rest_resource(request: Request, _uri_template=uri, _resource=resource
_register_rest_endpoints()


@dataclass
class RunConfig:
"""Runtime configuration for FastMCP transport servers."""

host: str | None = None
port: int | None = None
path: str | None = None

def to_kwargs(self) -> dict[str, object]:
"""Return keyword arguments compatible with ``FastMCP.run``."""

kwargs: dict[str, object] = {}
if self.host is not None:
kwargs["host"] = self.host
if self.port is not None:
kwargs["port"] = self.port
if self.path:
kwargs["path"] = self.path
return kwargs


def main(argv: list[str] | None = None) -> None:
"""CLI entrypoint for running the MCP server."""
parser = argparse.ArgumentParser(description="Run the MCP server")
Expand Down Expand Up @@ -1616,16 +1645,19 @@ def main(argv: list[str] | None = None) -> None:
if transport == "stdio" and mount:
parser.error("--mount or MCP_MOUNT is not allowed when transport is stdio")

run_kwargs: dict[str, Any] = {}
run_config = RunConfig()
if transport != "stdio":
run_kwargs.update({"host": host, "port": port})
if host is not None:
run_config.host = host
if port is not None:
run_config.port = port
if mount:
run_kwargs["path"] = mount
run_config.path = mount

server.settings.dense_model = args.dense_model
server.settings.sparse_model = args.sparse_model

server.run(transport=transport, **run_kwargs)
server.run(transport=transport, **run_config.to_kwargs())


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "mcp-plex"
version = "1.0.22"
version = "1.0.23"

description = "Plex-Oriented Model Context Protocol Server"
requires-python = ">=3.11,<3.13"
Expand Down
42 changes: 42 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from mcp_plex import loader
from mcp_plex import server as server_module
from pydantic import ValidationError


@contextmanager
Expand Down Expand Up @@ -517,6 +518,47 @@ async def _noop() -> None:
assert module._request_model("noop", _noop) is None


def test_request_model_missing_annotation_uses_object():
module = importlib.import_module("mcp_plex.server")

async def _unannotated(foo): # type: ignore[no-untyped-def]
return foo

request_model = module._request_model("unannotated", _unannotated)
assert request_model is not None
field = request_model.model_fields["foo"]
assert field.annotation is object
with pytest.raises(ValidationError):
request_model()
instance = request_model(foo="value")
assert instance.foo == "value"


def test_normalize_identifier_scalar_inputs():
module = importlib.import_module("mcp_plex.server")

assert module._normalize_identifier(" value ") == "value"
assert module._normalize_identifier(123) == "123"
assert module._normalize_identifier(0.0) == "0.0"
assert module._normalize_identifier("") is None
assert module._normalize_identifier(None) is None


def test_run_config_to_kwargs():
module = importlib.import_module("mcp_plex.server")

config = module.RunConfig()
assert config.to_kwargs() == {}

config.host = "127.0.0.1"
config.port = 8080
assert config.to_kwargs() == {"host": "127.0.0.1", "port": 8080}

config.path = "/plex"
kwargs = config.to_kwargs()
assert kwargs["path"] == "/plex"


def test_find_records_handles_retrieve_error(monkeypatch):
with _load_server(monkeypatch) as module:
async def fail_retrieve(*args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.