Skip to content

Commit 11b711f

Browse files
committed
refactor: refactored msg_to_toml and toml_to_message into Message methods
1 parent c237dde commit 11b711f

File tree

2 files changed

+59
-58
lines changed

2 files changed

+59
-58
lines changed

gptme/message.py

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sys
44
import textwrap
55
from datetime import datetime
6-
from typing import Literal
6+
from typing import Literal, Self
77

88
import tomlkit
99
from rich import print
@@ -40,6 +40,10 @@ def __init__(
4040
# Wether this message should be printed on execution (will still print on resume, unlike hide)
4141
self.quiet = quiet
4242

43+
def __repr__(self):
44+
content = textwrap.shorten(self.content, 20, placeholder="...")
45+
return f"<Message role={self.role} content={content}>"
46+
4347
def to_dict(self, keys=None):
4448
"""Return a dict representation of the message, serializable to JSON."""
4549
d = {
@@ -54,9 +58,48 @@ def to_dict(self, keys=None):
5458
def format(self, oneline: bool = False, highlight: bool = False) -> str:
5559
return format_msgs([self], oneline=oneline, highlight=highlight)[0]
5660

57-
def __repr__(self):
58-
content = textwrap.shorten(self.content, 20, placeholder="...")
59-
return f"<Message role={self.role} content={content}>"
61+
def to_toml(self) -> str:
62+
"""Converts a message to a TOML string, for easy editing by hand in editor to then be parsed back."""
63+
flags = []
64+
if self.pinned:
65+
flags.append("pinned")
66+
if self.hide:
67+
flags.append("hide")
68+
if self.quiet:
69+
flags.append("quiet")
70+
flags_toml = "\n".join(f"{flag} = true" for flag in flags)
71+
72+
# doublequotes need to be escaped
73+
content = self.content.replace('"', '\\"')
74+
return f'''[message]
75+
role = "{self.role}"
76+
content = """
77+
{content}
78+
"""
79+
timestamp = "{self.timestamp.isoformat()}"
80+
{flags_toml}
81+
'''
82+
83+
@classmethod
84+
def from_toml(cls, toml: str) -> Self:
85+
"""
86+
Converts a TOML string to a message.
87+
88+
The string can be a single [[message]].
89+
"""
90+
91+
t = tomlkit.parse(toml)
92+
assert "message" in t and isinstance(t["message"], dict)
93+
msg: dict = t["message"] # type: ignore
94+
95+
return cls(
96+
msg["role"],
97+
msg["content"],
98+
pinned=msg.get("pinned", False),
99+
hide=msg.get("hide", False),
100+
quiet=msg.get("quiet", False),
101+
timestamp=datetime.fromisoformat(msg["timestamp"]),
102+
)
60103

61104
def get_codeblocks(self, content=False) -> list[str]:
62105
"""
@@ -154,58 +197,15 @@ def print_msg(
154197
)
155198

156199

157-
def msg_to_toml(msg: Message) -> str:
158-
"""Converts a message to a TOML string, for easy editing by hand in editor to then be parsed back."""
159-
# TODO: escape msg.content
160-
flags = []
161-
if msg.pinned:
162-
flags.append("pinned")
163-
if msg.hide:
164-
flags.append("hide")
165-
if msg.quiet:
166-
flags.append("quiet")
167-
168-
# doublequotes need to be escaped
169-
content = msg.content.replace('"', '\\"')
170-
return f'''[message]
171-
role = "{msg.role}"
172-
content = """
173-
{content}
174-
"""
175-
timestamp = "{msg.timestamp.isoformat()}"
176-
'''
177-
178-
179200
def msgs_to_toml(msgs: list[Message]) -> str:
180201
"""Converts a list of messages to a TOML string, for easy editing by hand in editor to then be parsed back."""
181202
t = ""
182203
for msg in msgs:
183-
t += msg_to_toml(msg).replace("[message]", "[[messages]]") + "\n\n"
204+
t += msg.to_toml().replace("[message]", "[[messages]]") + "\n\n"
184205

185206
return t
186207

187208

188-
def toml_to_msg(toml: str) -> Message:
189-
"""
190-
Converts a TOML string to a message.
191-
192-
The string can be a single [[message]].
193-
"""
194-
195-
t = tomlkit.parse(toml)
196-
assert "message" in t and isinstance(t["message"], dict)
197-
msg: dict = t["message"] # type: ignore
198-
199-
return Message(
200-
msg["role"],
201-
msg["content"],
202-
pinned=msg.get("pinned", False),
203-
hide=msg.get("hide", False),
204-
quiet=msg.get("quiet", False),
205-
timestamp=datetime.fromisoformat(msg["timestamp"]),
206-
)
207-
208-
209209
def toml_to_msgs(toml: str) -> list[Message]:
210210
"""
211211
Converts a TOML string to a list of messages.

tests/test_message.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,25 @@
1-
from gptme.message import (
2-
Message,
3-
msg_to_toml,
4-
msgs_to_toml,
5-
toml_to_msg,
6-
toml_to_msgs,
7-
)
1+
from gptme.message import Message, msgs_to_toml, toml_to_msgs
82

93

104
def test_toml():
5+
# single message, check escaping
116
msg = Message(
127
"system",
138
'''Hello world!
149
"""Difficult to handle string"""
1510
''',
1611
)
17-
t = msg_to_toml(msg)
12+
t = msg.to_toml()
1813
print(t)
19-
m = toml_to_msg(t)
14+
m = Message.from_toml(t)
2015
print(m)
2116
assert msg.content == m.content
2217
assert msg.role == m.role
2318
assert msg.timestamp.date() == m.timestamp.date()
2419
assert msg.timestamp.timetuple() == m.timestamp.timetuple()
2520

26-
msg2 = Message("user", "Hello computer!")
21+
# multiple messages
22+
msg2 = Message("user", "Hello computer!", pinned=True, hide=True, quiet=True)
2723
ts = msgs_to_toml([msg, msg2])
2824
print(ts)
2925
ms = toml_to_msgs(ts)
@@ -34,6 +30,11 @@ def test_toml():
3430
assert ms[0].content == msg.content
3531
assert ms[1].content == msg2.content
3632

33+
# check flags
34+
assert ms[1].pinned == msg2.pinned
35+
assert ms[1].hide == msg2.hide
36+
assert ms[1].quiet == msg2.quiet
37+
3738

3839
def test_get_codeblocks():
3940
# single codeblock only

0 commit comments

Comments
 (0)