Skip to content

Commit

Permalink
✨ support message rollback
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Apr 24, 2024
1 parent 1e73aa5 commit e064671
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 60 deletions.
118 changes: 59 additions & 59 deletions nonebot/adapters/satori/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path
from base64 import b64encode
from dataclasses import InitVar, field, dataclass
from typing_extensions import Self, Required, NotRequired, override
from typing_extensions import Self, NotRequired, override
from typing import Any, Dict, List, Type, Tuple, Union, Iterable, Optional, TypedDict

from nonebot.adapters import Message as BaseMessage
Expand All @@ -12,7 +12,14 @@
from .element import Element, parse, escape, param_case


@dataclass
class MessageSegment(BaseMessageSegment["Message"]):
_children: "Message" = field(init=False, default_factory=lambda: Message())

@property
def children(self) -> "Message":
return self._children

def __str__(self) -> str:
def _attr(key: str, value: Any):
if value is None:
Expand All @@ -27,7 +34,16 @@ def _attr(key: str, value: Any):
attrs = "".join(_attr(k, v) for k, v in self.data.items())
if self.type == "text" and "text" in self.data:
return escape(self.data["text"])
return f"<{self.type}{attrs}/>"
inner = "".join(str(c) for c in self._children)
if not self._children:
return f"<{self.type}{attrs}/>"
return f"<{self.type}{attrs}>{inner}</{self.type}>"

def __call__(self, *segments: Union[str, Iterable["MessageSegment"], "MessageSegment"]) -> Self:
for seg in segments:
self._children.__iadd__(seg)
self._children.__merge_text__()
return self

@classmethod
@override
Expand Down Expand Up @@ -266,7 +282,7 @@ def input_button(text: str, display: Optional[str] = None, theme: Optional[str]

@staticmethod
def raw(content: str) -> "Raw":
return Raw("raw", {"raw": content})
return Raw("raw", {"text": content})

@staticmethod
def br() -> "Br":
Expand Down Expand Up @@ -341,7 +357,7 @@ def is_text(self) -> bool:


class RawData(TypedDict):
raw: str
text: str


@dataclass
Expand All @@ -350,7 +366,7 @@ class Raw(MessageSegment):

@override
def __str__(self) -> str:
return self.data["raw"]
return self.data["text"]

@override
def is_text(self) -> bool:
Expand Down Expand Up @@ -471,8 +487,11 @@ class Link(MessageSegment):

@override
def __str__(self):
inner = "".join(str(c) for c in self._children)
if "display" in self.data:
return f'<a href="{escape(self.data["text"])}">{escape(self.data["display"])}</a>'
return f'<a href="{escape(self.data["text"])}">{escape(self.data["display"])}{inner}</a>'
if inner:
return f'<a href="{escape(self.data["text"])}">{inner}</a>'
return f'<a href="{escape(self.data["text"])}"/>'

@override
Expand Down Expand Up @@ -570,24 +589,15 @@ def is_text(self) -> bool:
class RenderMessageData(TypedDict):
id: NotRequired[str]
forward: NotRequired[bool]
content: NotRequired["Message"]


@dataclass
class RenderMessage(MessageSegment):
data: RenderMessageData = field(default_factory=dict) # type: ignore

@override
def __str__(self):
attr = []
if "id" in self.data:
attr.append(f' id="{escape(self.data["id"])}"')
if self.data.get("forward"):
attr.append(" forward")
if "content" not in self.data:
return f'<{self.type}{"".join(attr)} />'
else:
return f'<{self.type}{"".join(attr)}>{self.data["content"]}</{self.type}>'
@property
def content(self) -> Optional["Message"]:
return self._children or None


class AuthorData(TypedDict):
Expand Down Expand Up @@ -625,37 +635,17 @@ def __str__(self):
attr.append(f'text="{escape(self.data["text"])}"') # type: ignore
if "theme" in self.data:
attr.append(f'theme="{escape(self.data["theme"])}"')
inner = "".join(str(c) for c in self._children)
if "display" in self.data:
return f'<button {" ".join(attr)}>{escape(self.data["display"])}</button>'
return f'<button {" ".join(attr)}>{escape(self.data["display"])}{inner}</button>'
if inner:
return f'<button {" ".join(attr)}>{inner}</button>'
return f'<button {" ".join(attr)} />'


class CustomData(TypedDict, total=False):
_children: Required[List["MessageSegment"]]


@dataclass
class Custom(MessageSegment):
data: CustomData = field(default_factory=dict) # type: ignore

def __str__(self) -> str:
def _attr(key: str, value: Any):
if value is None:
return ""
key = param_case(key)
if value is True:
return f" {key}"
if value is False:
return f" no-{key}"
return f' {key}="{escape(str(value), True)}"'

attrs = "".join(_attr(k, v) for k, v in self.data.items() if k != "_children")
if self.type == "text" and "text" in self.data:
return escape(self.data["text"])
inner = "".join(str(c) for c in self.data["_children"])
if not inner:
return f"<{self.type}{attrs}/>"
return f"<{self.type}{attrs}>{inner}</{self.type}>"
pass


ELEMENT_TYPE_MAP = {
Expand Down Expand Up @@ -699,15 +689,21 @@ def handle(element: Element, upper_styles: Optional[List[str]] = None):
tag = element.tag()
if tag in ELEMENT_TYPE_MAP:
seg_cls, seg_type = ELEMENT_TYPE_MAP[tag]
yield seg_cls(seg_type, element.attrs.copy())
yield seg_cls(seg_type, element.attrs.copy())(
*(handle(child, [*(upper_styles or [])]) for child in element.children)
)
elif tag in ("a", "link"):
if element.children:
yield Link("link", {"text": element.attrs["href"], "display": element.children[0].attrs["text"]})
yield Link("link", {"text": element.attrs["href"], "display": element.children[0].attrs["text"]})(
*(handle(child, [*(upper_styles or [])]) for child in element.children[1:])
)
else:
yield Link("link", {"text": element.attrs["href"]})
elif tag == "button":
if element.children:
yield Button("button", {"display": element.children[0].attrs["text"], **element.attrs}) # type: ignore
yield Button("button", {"display": element.children[0].attrs["text"], **element.attrs})( # type: ignore
*(handle(child, [*(upper_styles or [])]) for child in element.children[1:]),
)
else:
yield Button("button", {**element.attrs}) # type: ignore
elif tag in STYLE_TYPE_MAP:
Expand All @@ -728,15 +724,13 @@ def handle(element: Element, upper_styles: Optional[List[str]] = None):
elif tag in ("br", "newline"):
yield Br("br", {"text": "\n"})
elif tag in ("message", "quote"):
data = element.attrs.copy()
if element.children:
data["content"] = Message.from_satori_element(element.children)
yield RenderMessage(tag, data) # type: ignore
yield RenderMessage(tag, element.attrs.copy())( # type: ignore
*(handle(child, [*(upper_styles or [])]) for child in element.children),
)
else:
custom = Custom(element.tag(), {**element.attrs, "_children": []}) # type: ignore
for child in element.children:
custom.data["_children"].extend(handle(child))
yield custom
yield Custom(tag, element.attrs.copy())(
*(handle(child, [*(upper_styles or [])]) for child in element.children)
)


class Message(BaseMessage[MessageSegment]):
Expand All @@ -754,8 +748,7 @@ def __add__(self, other: Union[str, MessageSegment, Iterable[MessageSegment]]) -
@override
def __radd__(self, other: Union[str, MessageSegment, Iterable[MessageSegment]]) -> "Message":
result = self.__class__(MessageSegment.text(other) if isinstance(other, str) else other)
result = result + self
return result.__merge_text__()
return result + self

@staticmethod
@override
Expand All @@ -773,14 +766,21 @@ def from_satori_element(cls, elements: List[Element]) -> "Message":

@override
def extract_plain_text(self) -> str:
return "".join(seg.data["text"] for seg in self if seg.is_text())
text = str(self)
return re.sub(r"<[^>]+>", "", text)

def query(self, type_: str):
for seg in self:
if seg.type == type_:
yield seg
yield from seg.children.query(type_)

def __merge_text__(self) -> Self:
if not self:
return self
result = []
last = self[0]
for seg in self[1:]:
last = list.__getitem__(self, 0)
for seg in list.__getitem__(self, slice(1, None)):
if last.type == "text" and seg.type == "text":
assert isinstance(last, Text)
_len = len(last.data["text"])
Expand Down
16 changes: 15 additions & 1 deletion tests/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_message():
)
assert Message.from_satori_element(
parse("<chronocat:face id='265' name='[辣眼睛]' platform='chronocat'/>")
)[0].data == {"id": "265", "name": "[辣眼睛]", "platform": "chronocat", "_children": []}
)[0].data == {"id": "265", "name": "[辣眼睛]", "platform": "chronocat"}


@pytest.mark.asyncio
Expand All @@ -43,3 +43,17 @@ async def test_message_rich_expr():
assert str(msg4) == '<b>123<i>456</i></b><img src="url"/><b><i>789</i>abc</b>'
assert Message.from_satori_element(parse(str(msg4))) == msg4
assert msg4.extract_plain_text() == "123456789abc"


def test_message_fallback():
code = """\
<video src="http://aa.com/a.mp4">
当前平台不支持发送视频,请在
<a href="http://aa.com/a.mp4">这里</a>
观看视频!
</video>
"""
msg = Message.from_satori_element(parse(code))
assert str(msg[0].children) == '当前平台不支持发送视频,请在<a href="http://aa.com/a.mp4">这里</a>观看视频!'
assert msg.extract_plain_text() == "当前平台不支持发送视频,请在这里观看视频!"
assert list(msg.query("link"))[0].data["text"] == "http://aa.com/a.mp4"

0 comments on commit e064671

Please sign in to comment.