Skip to content

Commit d0a2245

Browse files
committed
refactor: refactored prompts and get_codeblock
1 parent 32df015 commit d0a2245

File tree

10 files changed

+238
-142
lines changed

10 files changed

+238
-142
lines changed

gptme/cli.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
to do so, it needs to be able to store and query past conversations in a database.
2121
"""
2222
# The above may be used as a prompt for the agent.
23-
2423
import atexit
2524
import errno
2625
import importlib.metadata
@@ -44,7 +43,7 @@
4443
from .llm import init_llm, reply
4544
from .logmanager import LogManager, _conversations
4645
from .message import Message
47-
from .prompts import initial_prompt_single_message
46+
from .prompts import get_prompt
4847
from .tabcomplete import register_tabcomplete
4948
from .tools import execute_msg, init_tools
5049
from .util import epoch_to_age, generate_unique_name
@@ -166,29 +165,27 @@ def main(
166165
if no_confirm:
167166
logger.warning("Skipping all confirmation prompts.")
168167

169-
if prompt_system in ["full", "short"]:
170-
promptmsgs = [initial_prompt_single_message(short=prompt_system == "short")]
171-
else:
172-
promptmsgs = [Message("system", prompt_system)]
173-
174-
# we need to run this before checking stdin, since the interactive doesn't work with the switch back to interactive mode
175-
logfile = get_logfile(
176-
name, interactive=(not prompts and interactive) and sys.stdin.isatty()
177-
)
178-
print(f"Using logdir {logfile.parent}")
179-
log = LogManager.load(logfile, initial_msgs=promptmsgs, show_hidden=show_hidden)
168+
# get initial system prompt
169+
prompt_msgs = [get_prompt(prompt_system)]
180170

181-
# if stdin is not a tty, we're getting piped input
171+
# if stdin is not a tty, we're getting piped input, which we should include in the prompt
182172
if not sys.stdin.isatty():
183173
# fetch prompt from stdin
184174
prompt_stdin = _read_stdin()
185175
if prompt_stdin:
186-
promptmsgs += [Message("system", prompt_stdin)]
176+
prompt_msgs += [Message("system", f"```stdin\n{prompt_stdin}\n```")]
187177

188178
# Attempt to switch to interactive mode
189179
sys.stdin.close()
190180
sys.stdin = open("/dev/tty")
191181

182+
# we need to run this before checking stdin, since the interactive doesn't work with the switch back to interactive mode
183+
logfile = get_logfile(
184+
name, interactive=(not prompts and interactive) and sys.stdin.isatty()
185+
)
186+
print(f"Using logdir {logfile.parent}")
187+
log = LogManager.load(logfile, initial_msgs=prompt_msgs, show_hidden=show_hidden)
188+
192189
# print log
193190
log.print()
194191
print("--- ^^^ past messages ^^^ ---")

gptme/logmanager.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from .constants import CMDFIX, LOGSDIR
1111
from .message import Message, print_msg
12-
from .prompts import initial_prompt
12+
from .prompts import get_prompt
1313
from .tools.reduce import limit_log, reduce_log
1414
from .util import len_tokens
1515

@@ -111,7 +111,7 @@ def prepare_messages(self) -> list[Message]:
111111
def load(
112112
cls,
113113
logfile: PathLike,
114-
initial_msgs: list[Message] = list(initial_prompt()),
114+
initial_msgs: list[Message] = [get_prompt()],
115115
**kwargs,
116116
) -> "LogManager":
117117
"""Loads a conversation log."""
@@ -146,13 +146,9 @@ def get_last_code_block(
146146
msgs = msgs[-history:]
147147

148148
for msg in msgs[::-1]:
149-
# check if message contains a code block
150-
backtick_count = msg.content.count("```")
151-
if backtick_count >= 2:
152-
if content:
153-
return msg.content.split("```")[-2].split("\n", 1)[-1]
154-
else:
155-
return msg.content
149+
codeblocks = msg.get_codeblocks(content=content)
150+
if codeblocks:
151+
return codeblocks[-1]
156152
return None
157153

158154
def rename(self, name: str, keep_date=False) -> None:

gptme/message.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,35 @@ def __repr__(self):
5858
content = textwrap.shorten(self.content, 20, placeholder="...")
5959
return f"<Message role={self.role} content={content}>"
6060

61+
def get_codeblocks(self, content=False) -> list[str]:
62+
"""
63+
Get all codeblocks.
64+
If `content` set, return the content of the code block, else return the whole message.
65+
"""
66+
codeblocks = []
67+
content_str = self.content
68+
# prepend newline to make sure we get the first codeblock
69+
if not content_str.startswith("\n"):
70+
content_str = "\n" + content_str
71+
72+
# check if message contains a code block
73+
backtick_count = content_str.count("\n```")
74+
if backtick_count < 2:
75+
return []
76+
for i in range(1, backtick_count, 2):
77+
codeblock_str = content_str.split("\n```")[i]
78+
# get codeblock language or filename from first line
79+
lang_or_fn = codeblock_str.split("\n")[0]
80+
codeblock_str = "\n".join(codeblock_str.split("\n")[1:])
81+
82+
if content:
83+
codeblocks.append(codeblock_str)
84+
else:
85+
full_codeblock = f"```{lang_or_fn}\n{codeblock_str}\n```"
86+
codeblocks.append(full_codeblock)
87+
88+
return codeblocks
89+
6190

6291
def format_msgs(
6392
msgs: list[Message],

0 commit comments

Comments
 (0)