Skip to content

Commit 57fe5ba

Browse files
committed
fix: wrapped annotations handling in func_metadata
Use the original function's __globals__ for type hint resolution when dealing with wrapped functions. This ensures that any type hints defined in the original function's module are correctly resolved. This also includes adding common typing names for resiliency. Fixes #1391
1 parent 814c9c0 commit 57fe5ba

File tree

4 files changed

+62
-3
lines changed

4 files changed

+62
-3
lines changed

src/mcp/server/fastmcp/utilities/func_metadata.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -470,18 +470,27 @@ def try_eval_type(value: Any, globalns: dict[str, Any], localns: dict[str, Any])
470470

471471
def _get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
472472
"""Get function signature while evaluating forward references"""
473+
473474
signature = inspect.signature(call)
474-
globalns = getattr(call, "__globals__", {})
475+
try:
476+
type_hints = get_type_hints(call, include_extras=True)
477+
except TypeError:
478+
# get_type_hints doesn't handle callable objects.
479+
type_hints = {}
480+
481+
def resolve_annotation(name: str, annotation: Any) -> Any:
482+
return type_hints.get(name, annotation) if isinstance(annotation, str) else annotation
483+
475484
typed_params = [
476485
inspect.Parameter(
477486
name=param.name,
478487
kind=param.kind,
479488
default=param.default,
480-
annotation=_get_typed_annotation(param.annotation, globalns),
489+
annotation=resolve_annotation(param.name, param.annotation),
481490
)
482491
for param in signature.parameters.values()
483492
]
484-
typed_return = _get_typed_annotation(signature.return_annotation, globalns)
493+
typed_return = resolve_annotation("return", signature.return_annotation)
485494
typed_signature = inspect.Signature(typed_params, return_annotation=typed_return)
486495
return typed_signature
487496

tests/server/fastmcp/test_func_metadata.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from mcp.server.fastmcp.utilities.func_metadata import func_metadata
1616

17+
from .test_wrapped import wrapped_function
18+
1719

1820
class SomeInputModelA(BaseModel):
1921
pass
@@ -1094,3 +1096,20 @@ def func_with_reserved_json(
10941096
assert result["json"] == {"nested": "data"}
10951097
assert result["model_dump"] == [1, 2, 3]
10961098
assert result["normal"] == "plain string"
1099+
1100+
1101+
@pytest.mark.anyio
1102+
async def test_wrapped_annotations_func() -> None:
1103+
"""Test that func_metadata works with wrapped annotations functions."""
1104+
meta = func_metadata(wrapped_function)
1105+
1106+
result = await meta.call_fn_with_arg_validation(
1107+
wrapped_function,
1108+
fn_is_async=False,
1109+
arguments_to_validate={
1110+
"literal": "test",
1111+
},
1112+
arguments_to_pass_directly=None,
1113+
)
1114+
1115+
assert result == "test"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from collections.abc import Callable
2+
from functools import wraps
3+
from typing import TypeVar
4+
5+
from typing_extensions import ParamSpec
6+
7+
P = ParamSpec("P")
8+
R = TypeVar("R")
9+
10+
11+
def instrument(func: Callable[P, R]) -> Callable[P, R]:
12+
"""
13+
Example decorator that logs before/after the call
14+
while preserving the original function's type signature.
15+
"""
16+
17+
@wraps(func)
18+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
19+
return func(*args, **kwargs)
20+
21+
return wrapper
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from __future__ import annotations
2+
3+
from typing import Literal
4+
5+
from .test_instrument import instrument
6+
7+
8+
@instrument
9+
def wrapped_function(literal: Literal["test"] | None = None) -> Literal["test"] | None:
10+
return literal

0 commit comments

Comments
 (0)