Skip to content

Commit

Permalink
Add devchat log --insert
Browse files Browse the repository at this point in the history
  • Loading branch information
basicthinker committed Oct 9, 2023
1 parent 28a89fc commit d6a1765
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 16 deletions.
40 changes: 34 additions & 6 deletions devchat/_cli/log.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
import json
import sys
from typing import Optional, List, Dict
from pydantic import BaseModel
import rich_click as click
from devchat.openai.openai_chat import OpenAIChat, OpenAIChatConfig
from devchat.openai.openai_chat import OpenAIChat, OpenAIChatConfig, OpenAIPrompt
from devchat.store import Store
from devchat.utils import get_logger
from devchat.utils import get_logger, get_user_info
from devchat._cli.utils import handle_errors, init_dir, get_model_config


class PromptData(BaseModel):
    """Validated schema for the JSON payload passed to ``devchat log --insert``."""
    # Chat model identifier, e.g. "gpt-3.5-turbo".
    model: str
    # Conversation messages as OpenAI-style dicts with "role" and "content" keys.
    messages: List[Dict]
    # Hash of the parent prompt; None makes this prompt a topic root.
    parent: Optional[str] = None
    # Hashes of referenced prompts; empty by default.
    # NOTE: a mutable default is safe here — pydantic copies field defaults per instance.
    references: Optional[List[str]] = []
    # Creation time of the prompt — presumably a Unix timestamp in seconds
    # (test data uses values like 1610000000); confirm against consumers.
    timestamp: int
    # Token count of the request sent to the model.
    request_tokens: int
    # Token count of the model's response.
    response_tokens: int


logger = get_logger(__name__)


Expand All @@ -14,13 +27,15 @@
@click.option('-n', '--max-count', default=1, help='Limit the number of commits to output.')
@click.option('-t', '--topic', 'topic_root', default=None,
help='Hash of the root prompt of the topic to select prompts from.')
@click.option('--delete', default=None, help='Delete a leaf prompt from the log.')
def log(skip, max_count, topic_root, delete):
@click.option('--insert', default=None, help='JSON string of the prompt to insert into the log.')
@click.option('--delete', default=None, help='Hash of the leaf prompt to delete from the log.')
def log(skip, max_count, topic_root, insert, delete):
"""
Manage the prompt history.
"""
if delete and (skip != 0 or max_count != 1 or topic_root is not None):
click.echo("Error: The --delete option cannot be used with other options.", err=True)
if (insert or delete) and (skip != 0 or max_count != 1 or topic_root is not None):
click.echo("Error: The --insert or --delete option cannot be used with other options.",
err=True)
sys.exit(1)

repo_chat_dir, user_chat_dir = init_dir()
Expand All @@ -39,6 +54,19 @@ def log(skip, max_count, topic_root, delete):
else:
click.echo(f"Failed to delete prompt {delete}.")
else:
if insert:
prompt_data = PromptData(**json.loads(insert))
user, email = get_user_info()
prompt = OpenAIPrompt(prompt_data.model, user, email)
prompt.model = prompt_data.model
prompt.input_messages(prompt_data.messages)
prompt.parent = prompt_data.parent
prompt.references = prompt_data.references
prompt._timestamp = prompt_data.timestamp
prompt._request_tokens = prompt_data.request_tokens
prompt._response_tokens = prompt_data.response_tokens
store.store_prompt(prompt)

recent_prompts = store.select_prompts(skip, skip + max_count, topic_root)
logs = []
for record in recent_prompts:
Expand Down
16 changes: 10 additions & 6 deletions devchat/openai/openai_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,15 @@ def input_messages(self, messages: List[dict]):
logger.warning("Invalid new context message: %s", message)

if not self.request:
last_user_message = self._history_messages[Message.CHAT].pop()
if last_user_message.role in ("user", "function"):
self._new_messages["request"] = last_user_message
else:
logger.warning("Invalid user request: %s", last_user_message)
while True:
last_message = self._history_messages[Message.CHAT].pop()
if last_message.role in ("user", "function"):
self._new_messages["request"] = last_message
break
if last_message.role == "assistant":
self._new_messages["responses"].append(last_message)
continue
self._history_messages[Message.CHAT].append(last_message)

def append_new(self, message_type: str, content: str,
available_tokens: int = sys.maxsize) -> bool:
Expand Down Expand Up @@ -232,7 +236,7 @@ def _validate_model(self, response_data: dict):
f"got '{response_data['model']}'")

def _timestamp_from_dict(self, response_data: dict):
if self._timestamp is None:
if not self._timestamp:
self._timestamp = response_data['created']
elif self._timestamp != response_data['created']:
raise ValueError(f"Time mismatch: expected {self._timestamp}, "
Expand Down
12 changes: 8 additions & 4 deletions devchat/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ class Prompt(ABC):
})
parent: str = None
references: List[str] = field(default_factory=list)
_timestamp: int = None
_timestamp: int = 0
_request_tokens: int = 0
_response_tokens: int = 0
_response_reasons: List[str] = field(default_factory=list)
_hash: str = None

def _complete_for_hash(self) -> bool:
def _complete_for_hashing(self) -> bool:
"""
Check if the prompt is complete for hashing.
Expand All @@ -62,6 +62,10 @@ def _complete_for_hash(self) -> bool:
self.request, self.responses)
return False

if not self.timestamp:
logger.warning("Prompt lacks timestamp for hashing: %s", self.request)
return False

if not self._response_tokens:
return False

Expand Down Expand Up @@ -114,7 +118,7 @@ def messages(self) -> List[dict]:
def input_messages(self, messages: List[dict]):
"""
Input the messages from the chat API to new and history messages.
The message list should be generated by the `messages` property.
The message list must follow the convention of the `messages` property.
Args:
messages (List[dict]): The messages from the chat API.
Expand Down Expand Up @@ -185,7 +189,7 @@ def finalize_hash(self) -> str:
Returns:
str: The hash of the prompt. None if the prompt is incomplete.
"""
if not self._complete_for_hash():
if not self._complete_for_hashing():
self._hash = None

if self._hash:
Expand Down
81 changes: 81 additions & 0 deletions tests/test_cli_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,84 @@ def test_tokens_with_log(git_repo): # pylint: disable=W0613
logs = json.loads(result.output)
assert _within_range(logs[1]["request_tokens"], logs[0]["request_tokens"])
assert _within_range(logs[1]["response_tokens"], logs[0]["response_tokens"])


def test_log_insert(git_repo):  # pylint: disable=W0613
    """Exercise ``devchat log --insert`` end to end.

    Inserts two topic-root prompts and a third prompt that is a child of
    the first, then verifies both the chronological log order and the
    topic listing.
    """

    def insert_prompt(payload):
        # Run the CLI insert command and return the stored prompt record.
        outcome = runner.invoke(main, ['log', '--insert', payload])
        return json.loads(outcome.output)[0]

    chat1 = """{
    "model": "gpt-3.5-turbo",
    "messages": [
        {
            "role": "user",
            "content": "This is Topic 1. Reply the topic number."
        },
        {
            "role": "assistant",
            "content": "Topic 1"
        }
    ],
    "timestamp": 1610000000,
    "request_tokens": 100,
    "response_tokens": 100
}"""
    prompt1 = insert_prompt(chat1)

    chat2 = """{
    "model": "gpt-3.5-turbo",
    "messages": [
        {
            "role": "user",
            "content": "This is Topic 2. Reply the topic number."
        },
        {
            "role": "assistant",
            "content": "Topic 2"
        }
    ],
    "timestamp": 1620000000,
    "request_tokens": 200,
    "response_tokens": 200
}"""
    prompt2 = insert_prompt(chat2)

    # The third prompt links to prompt1 via its "parent" hash.
    chat3 = """{
    "model": "gpt-3.5-turbo",
    "messages": [
        {
            "role": "user",
            "content": "Let's continue with Topic 1."
        },
        {
            "role": "assistant",
            "content": "Sure!"
        }
    ],
    "parent": "%s",
    "timestamp": 1630000000,
    "request_tokens": 300,
    "response_tokens": 300
}""" % prompt1['hash']
    prompt3 = insert_prompt(chat3)
    assert prompt3['parent'] == prompt1['hash']

    # Log order is newest-first: prompt3, prompt2, prompt1.
    outcome = runner.invoke(main, ['log', '-n', 3])
    entries = json.loads(outcome.output)
    expected_order = [prompt3['hash'], prompt2['hash'], prompt1['hash']]
    assert [entry['hash'] for entry in entries] == expected_order

    # Only the two topic roots appear in the topic list, oldest root first.
    outcome = runner.invoke(main, ['topic', '--list'])
    topics = json.loads(outcome.output)
    assert topics[0]['root_prompt']['hash'] == prompt1['hash']
    assert topics[1]['root_prompt']['hash'] == prompt2['hash']

0 comments on commit d6a1765

Please sign in to comment.