Skip to content

Commit

Permalink
✨ upgrade to pydantic v2 (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
he0119 committed Feb 13, 2024
1 parent 974f5a8 commit 75db0ee
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 129 deletions.
6 changes: 4 additions & 2 deletions nonebot/adapters/red/adapter.py
Expand Up @@ -5,9 +5,11 @@

from nonebot.utils import escape_tag
from pydantic import ValidationError
from nonebot.compat import type_validate_python
from nonebot.drivers import Driver, Request, WebSocket, ForwardDriver
from nonebot.exception import ActionFailed, NetworkError, WebSocketClosed

from nonebot import get_plugin_config
from nonebot.adapters import Adapter as BaseAdapter

from .bot import Bot
Expand All @@ -31,7 +33,7 @@ class Adapter(BaseAdapter):
def __init__(self, driver: Driver, **kwargs: Any):
super().__init__(driver, **kwargs)
# 读取适配器所需的配置项
self.red_config: Config = Config.parse_obj(self.config)
self.red_config: Config = get_plugin_config(Config)
self._bots = self.red_config.red_bots
if self.red_config.red_auto_detect and not self._bots:
try:
Expand Down Expand Up @@ -159,7 +161,7 @@ def _handle_event(event_data: Any, target: Type[Event]):

def _handle_message(message: dict):
try:
_data = MessageModel.parse_obj(message)
_data = type_validate_python(MessageModel, message)
except ValidationError as e:
log(
"WARNING",
Expand Down
17 changes: 10 additions & 7 deletions nonebot/adapters/red/bot.py
Expand Up @@ -5,6 +5,7 @@
from typing import Any, List, Tuple, Union, Optional

from nonebot.message import handle_event
from nonebot.compat import type_validate_python

from nonebot.adapters import Bot as BaseBot
from nonebot.adapters import Adapter as BaseAdapter
Expand Down Expand Up @@ -185,7 +186,7 @@ async def send_message(
target=str(target),
elements=element_data,
)
return MessageModel.parse_obj(resp)
return type_validate_python(MessageModel, resp)

async def send_friend_message(
self,
Expand Down Expand Up @@ -242,22 +243,22 @@ async def send(
target=peerUin,
elements=element_data,
)
return MessageModel.parse_obj(resp)
return type_validate_python(MessageModel, resp)

async def get_self_profile(self) -> Profile:
"""获取登录账号自己的资料"""
resp = await self.call_api("get_self_profile")
return Profile.parse_obj(resp)
return type_validate_python(Profile, resp)

async def get_friends(self) -> List[Profile]:
"""获取登录账号所有好友的资料"""
resp = await self.call_api("get_friends")
return [Profile.parse_obj(data) for data in resp]
return [type_validate_python(Profile, data) for data in resp]

async def get_groups(self) -> List[Group]:
"""获取登录账号所有群组的资料"""
resp = await self.call_api("get_groups")
return [Group.parse_obj(data) for data in resp]
return [type_validate_python(Group, data) for data in resp]

async def mute_member(
self, group: int, *members: int, duration: Union[int, timedelta] = 60
Expand Down Expand Up @@ -342,7 +343,7 @@ async def get_members(self, group: int, size: int = 20) -> List[Member]:
size: 拉取多少个成员资料
"""
resp = await self.call_api("get_members", group=group, size=size)
return [Member.parse_obj(data["detail"]) for data in resp]
return [type_validate_python(Member, data["detail"]) for data in resp]

async def fetch(self, ms: BaseMessageSegment):
"""获取媒体消息段的二进制数据
Expand Down Expand Up @@ -397,7 +398,9 @@ async def upload(self, file: bytes) -> UploadResponse:
file: 上传的资源数据
"""
log("WARING", "This API is not suggest for user usage")
return UploadResponse.parse_obj(await self.call_api("upload", file=file))
return type_validate_python(
UploadResponse, await self.call_api("upload", file=file)
)

async def recall_message(
self,
Expand Down
22 changes: 22 additions & 0 deletions nonebot/adapters/red/compat.py
@@ -0,0 +1,22 @@
from typing import Literal, overload

from nonebot.compat import PYDANTIC_V2

__all__ = ("model_validator",)


if PYDANTIC_V2:
from pydantic import model_validator as model_validator
else:
from pydantic import root_validator

@overload
def model_validator(*, mode: Literal["before"]):
...

@overload
def model_validator(*, mode: Literal["after"]):
...

def model_validator(*, mode: Literal["before", "after"]):
return root_validator(pre=mode == "before", allow_reuse=True)
11 changes: 6 additions & 5 deletions nonebot/adapters/red/config.py
Expand Up @@ -3,7 +3,8 @@
from typing import Dict, List

from yarl import URL
from pydantic import Extra, Field, BaseModel
from pydantic import Field, BaseModel
from nonebot.compat import type_validate_python


class BotInfo(BaseModel):
Expand All @@ -24,16 +25,16 @@ class Server(BaseModel):
host: str = Field(default="localhost", alias="listen")


class Servers(BaseModel, extra=Extra.ignore):
class Servers(BaseModel):
servers: List[Server] = Field(default_factory=list)
enable: bool = True


class ChronocatConfig(Servers, extra=Extra.ignore):
class ChronocatConfig(Servers):
overrides: Dict[str, Servers] = Field(default_factory=dict)


class Config(BaseModel, extra=Extra.ignore):
class Config(BaseModel):
red_bots: List[BotInfo] = Field(default_factory=list)
"""bot 配置"""

Expand All @@ -53,7 +54,7 @@ def get_config() -> List[BotInfo]:
if not config.exists():
return []
with open(config, encoding="utf-8") as f:
chrono_config = ChronocatConfig.parse_obj(yaml.safe_load(f))
chrono_config = type_validate_python(ChronocatConfig, yaml.safe_load(f))
base_config = next(
(s for s in chrono_config.servers if s.type == "red" and s.enable), None
)
Expand Down
19 changes: 13 additions & 6 deletions nonebot/adapters/red/event.py
Expand Up @@ -5,11 +5,12 @@
from datetime import datetime, timedelta

from nonebot.utils import escape_tag
from pydantic.class_validators import root_validator
from nonebot.compat import model_dump, type_validate_python

from nonebot.adapters import Event as BaseEvent

from .message import Message
from .compat import model_validator
from .api.model import Message as MessageModel
from .api.model import MsgType, ChatType, ReplyElement, ShutUpTarget

Expand All @@ -27,7 +28,7 @@ def get_event_name(self) -> str:

@override
def get_event_description(self) -> str:
return escape_tag(str(self.dict()))
return escape_tag(str(model_dump(self)))

@override
def get_message(self):
Expand All @@ -51,7 +52,7 @@ def convert(cls, obj: Any):
子类可根据需要重写此方法
"""
return cls.parse_obj(obj)
return type_validate_python(cls, obj)


class MessageEvent(Event, MessageModel):
Expand Down Expand Up @@ -85,7 +86,7 @@ def get_event_name(self) -> str:
def get_message(self) -> Message:
return self.message

@root_validator(pre=True, allow_reuse=True)
@model_validator(mode="before")
def check_message(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if "elements" in values:
values["message"] = Message.from_red_message(
Expand Down Expand Up @@ -296,7 +297,11 @@ def convert(cls, obj: Any):
"peerUid": obj.peerUid,
"peerUin": obj.peerUin,
}
if obj.elements[0].grayTipElement and obj.elements[0].grayTipElement.xmlElement and obj.elements[0].grayTipElement.xmlElement.content: # type: ignore # noqa: E501
if (
obj.elements[0].grayTipElement
and obj.elements[0].grayTipElement.xmlElement
and obj.elements[0].grayTipElement.xmlElement.content
): # type: ignore # noqa: E501
# fmt: off
if not (mat := legacy_invite_message.search(obj.elements[0].grayTipElement.xmlElement.content)): # type: ignore # noqa: E501
raise ValueError("Invalid legacy invite message.")
Expand All @@ -306,7 +311,9 @@ def convert(cls, obj: Any):
else:
params["memberUid"] = obj.elements[0].grayTipElement.groupElement.memberUin # type: ignore # noqa: E501
params["operatorUid"] = obj.elements[0].grayTipElement.groupElement.adminUin # type: ignore # noqa: E501
params["memberName"] = obj.elements[0].grayTipElement.groupElement.memberNick # type: ignore # noqa: E501
params["memberName"] = obj.elements[
0
].grayTipElement.groupElement.memberNick # type: ignore # noqa: E501
return cls(**params)


Expand Down

0 comments on commit 75db0ee

Please sign in to comment.