Skip to content

Commit

Permalink
🐛 Fix: State ForwardRef 检测错误 (#2698)
Browse files Browse the repository at this point in the history
  • Loading branch information
yanyongyu committed May 9, 2024
1 parent 41b59cf commit 723fa4b
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 4 deletions.
12 changes: 10 additions & 2 deletions nonebot/internal/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@

from nonebot.dependencies import Param, Dependent
from nonebot.dependencies.utils import check_field_type
from nonebot.typing import T_State, T_Handler, T_DependencyCache
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
from nonebot.typing import (
_STATE_FLAG,
T_State,
T_Handler,
T_DependencyCache,
origin_is_annotated,
)
from nonebot.utils import (
get_name,
run_sync,
Expand Down Expand Up @@ -349,7 +355,9 @@ def _check_param(
cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
) -> Optional[Self]:
# param type is T_State
if param.annotation is T_State:
if origin_is_annotated(
get_origin(param.annotation)
) and _STATE_FLAG in get_args(param.annotation):
return cls()
# legacy: param is named "state" and has no type annotation
elif param.annotation == param.empty and param.name == "state":
Expand Down
10 changes: 9 additions & 1 deletion nonebot/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,15 @@ def evaluate_forwardref(


# state
T_State: TypeAlias = dict[t.Any, t.Any]
# use annotated flag to avoid ForwardRef recreate generic type (py >= 3.11)
class StateFlag:
def __repr__(self) -> str:
return "StateFlag()"


_STATE_FLAG = StateFlag()

T_State: TypeAlias = t.Annotated[dict[t.Any, t.Any], _STATE_FLAG]
"""事件处理状态 State 类型"""

_DependentCallable: TypeAlias = t.Union[
Expand Down
4 changes: 4 additions & 0 deletions tests/plugins/param/param_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ async def get_bot(b: Bot) -> Bot:
return b


async def postpone_bot(b: "Bot") -> Bot:
return b


async def legacy_bot(bot):
return bot

Expand Down
4 changes: 4 additions & 0 deletions tests/plugins/param/param_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ async def event(e: Event) -> Event:
return e


async def postpone_event(e: "Event") -> Event:
return e


async def legacy_event(event):
return event

Expand Down
6 changes: 5 additions & 1 deletion tests/plugins/param/param_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ async def matcher(m: Matcher) -> Matcher:
return m


async def postpone_matcher(m: "Matcher") -> Matcher:
return m


async def legacy_matcher(matcher):
return matcher

Expand All @@ -27,7 +31,7 @@ class BarMatcher(Matcher): ...


async def union_matcher(
m: Union[FooMatcher, BarMatcher]
m: Union[FooMatcher, BarMatcher],
) -> Union[FooMatcher, BarMatcher]:
return m

Expand Down
4 changes: 4 additions & 0 deletions tests/plugins/param/param_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ async def state(x: T_State) -> T_State:
return x


async def postpone_state(x: "T_State") -> T_State:
return x


async def legacy_state(state):
return state

Expand Down
21 changes: 21 additions & 0 deletions tests/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ async def test_bot(app: App):
union_bot,
legacy_bot,
generic_bot,
postpone_bot,
not_legacy_bot,
generic_bot_none,
)
Expand All @@ -138,6 +139,11 @@ async def test_bot(app: App):
ctx.pass_params(bot=bot)
ctx.should_return(bot)

async with app.test_dependent(postpone_bot, allow_types=[BotParam]) as ctx:
bot = ctx.create_bot()
ctx.pass_params(bot=bot)
ctx.should_return(bot)

async with app.test_dependent(legacy_bot, allow_types=[BotParam]) as ctx:
bot = ctx.create_bot()
ctx.pass_params(bot=bot)
Expand Down Expand Up @@ -188,6 +194,7 @@ async def test_event(app: App):
legacy_event,
event_message,
generic_event,
postpone_event,
event_plain_text,
not_legacy_event,
generic_event_none,
Expand All @@ -201,6 +208,10 @@ async def test_event(app: App):
ctx.pass_params(event=fake_event)
ctx.should_return(fake_event)

async with app.test_dependent(postpone_event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_event)
ctx.should_return(fake_event)

async with app.test_dependent(legacy_event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_event)
ctx.should_return(fake_event)
Expand Down Expand Up @@ -273,6 +284,7 @@ async def test_state(app: App):
legacy_state,
command_start,
regex_matched,
postpone_state,
not_legacy_state,
command_whitespace,
shell_command_args,
Expand Down Expand Up @@ -302,6 +314,10 @@ async def test_state(app: App):
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state)

async with app.test_dependent(postpone_state, allow_types=[StateParam]) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state)

async with app.test_dependent(legacy_state, allow_types=[StateParam]) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state)
Expand Down Expand Up @@ -414,6 +430,7 @@ async def test_matcher(app: App):
union_matcher,
legacy_matcher,
generic_matcher,
postpone_matcher,
not_legacy_matcher,
generic_matcher_none,
)
Expand All @@ -425,6 +442,10 @@ async def test_matcher(app: App):
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(fake_matcher)

async with app.test_dependent(postpone_matcher, allow_types=[MatcherParam]) as ctx:
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(fake_matcher)

async with app.test_dependent(legacy_matcher, allow_types=[MatcherParam]) as ctx:
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(fake_matcher)
Expand Down

0 comments on commit 723fa4b

Please sign in to comment.