diff --git a/src/mcp/server/fastmcp/utilities/func_metadata.py b/src/mcp/server/fastmcp/utilities/func_metadata.py index 3289a5aa6..52fdb981c 100644 --- a/src/mcp/server/fastmcp/utilities/func_metadata.py +++ b/src/mcp/server/fastmcp/utilities/func_metadata.py @@ -470,18 +470,27 @@ def try_eval_type(value: Any, globalns: dict[str, Any], localns: dict[str, Any]) def _get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: """Get function signature while evaluating forward references""" + signature = inspect.signature(call) - globalns = getattr(call, "__globals__", {}) + try: + type_hints = get_type_hints(call, include_extras=True) + except TypeError: + # get_type_hints doesn't handle callable objects. + type_hints = {} + + def resolve_annotation(name: str, annotation: Any) -> Any: + return type_hints.get(name, annotation) if isinstance(annotation, str) else annotation + typed_params = [ inspect.Parameter( name=param.name, kind=param.kind, default=param.default, - annotation=_get_typed_annotation(param.annotation, globalns), + annotation=resolve_annotation(param.name, param.annotation), ) for param in signature.parameters.values() ] - typed_return = _get_typed_annotation(signature.return_annotation, globalns) + typed_return = resolve_annotation("return", signature.return_annotation) typed_signature = inspect.Signature(typed_params, return_annotation=typed_return) return typed_signature diff --git a/tests/server/fastmcp/test_func_metadata.py b/tests/server/fastmcp/test_func_metadata.py index 830cf816b..63d93fbee 100644 --- a/tests/server/fastmcp/test_func_metadata.py +++ b/tests/server/fastmcp/test_func_metadata.py @@ -14,6 +14,8 @@ from mcp.server.fastmcp.utilities.func_metadata import func_metadata +from .test_wrapped import wrapped_function + class SomeInputModelA(BaseModel): pass @@ -1094,3 +1096,20 @@ def func_with_reserved_json( assert result["json"] == {"nested": "data"} assert result["model_dump"] == [1, 2, 3] assert result["normal"] == "plain string" + + +@pytest.mark.anyio +async def test_wrapped_annotations_func() -> None: + """Test that func_metadata works with wrapped annotations functions.""" + meta = func_metadata(wrapped_function) + + result = await meta.call_fn_with_arg_validation( + wrapped_function, + fn_is_async=False, + arguments_to_validate={ + "literal": "test", + }, + arguments_to_pass_directly=None, + ) + + assert result == "test" diff --git a/tests/server/fastmcp/test_instrument.py b/tests/server/fastmcp/test_instrument.py new file mode 100644 index 000000000..65090b8d7 --- /dev/null +++ b/tests/server/fastmcp/test_instrument.py @@ -0,0 +1,21 @@ +from collections.abc import Callable +from functools import wraps +from typing import TypeVar + +from typing_extensions import ParamSpec + +P = ParamSpec("P") +R = TypeVar("R") + + +def instrument(func: Callable[P, R]) -> Callable[P, R]: + """ + Example decorator that logs before/after the call + while preserving the original function's type signature. + """ + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + return func(*args, **kwargs) + + return wrapper diff --git a/tests/server/fastmcp/test_wrapped.py b/tests/server/fastmcp/test_wrapped.py new file mode 100644 index 000000000..6b1fe297c --- /dev/null +++ b/tests/server/fastmcp/test_wrapped.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from typing import Literal + +from .test_instrument import instrument + + +@instrument +def wrapped_function(literal: Literal["test"] | None = None) -> Literal["test"] | None: + return literal