diff --git a/devchat/_cli/log.py b/devchat/_cli/log.py
index 40809c09..009a4375 100644
--- a/devchat/_cli/log.py
+++ b/devchat/_cli/log.py
@@ -1,11 +1,24 @@
 import json
 import sys
+from typing import Optional, List, Dict
+from pydantic import BaseModel
 
 import rich_click as click
 
-from devchat.openai.openai_chat import OpenAIChat, OpenAIChatConfig
+from devchat.openai.openai_chat import OpenAIChat, OpenAIChatConfig, OpenAIPrompt
 from devchat.store import Store
-from devchat.utils import get_logger
+from devchat.utils import get_logger, get_user_info
 from devchat._cli.utils import handle_errors, init_dir, get_model_config
+
+
+class PromptData(BaseModel):
+    model: str
+    messages: List[Dict]
+    parent: Optional[str] = None
+    references: Optional[List[str]] = []
+    timestamp: int
+    request_tokens: int
+    response_tokens: int
+
 
 logger = get_logger(__name__)
@@ -14,13 +27,15 @@
 @click.option('-n', '--max-count', default=1, help='Limit the number of commits to output.')
 @click.option('-t', '--topic', 'topic_root', default=None,
               help='Hash of the root prompt of the topic to select prompts from.')
-@click.option('--delete', default=None, help='Delete a leaf prompt from the log.')
-def log(skip, max_count, topic_root, delete):
+@click.option('--insert', default=None, help='JSON string of the prompt to insert into the log.')
+@click.option('--delete', default=None, help='Hash of the leaf prompt to delete from the log.')
+def log(skip, max_count, topic_root, insert, delete):
     """
     Manage the prompt history.
     """
-    if delete and (skip != 0 or max_count != 1 or topic_root is not None):
-        click.echo("Error: The --delete option cannot be used with other options.", err=True)
+    if (insert or delete) and (skip != 0 or max_count != 1 or topic_root is not None):
+        click.echo("Error: The --insert or --delete option cannot be used with other options.",
+                   err=True)
         sys.exit(1)
 
     repo_chat_dir, user_chat_dir = init_dir()
@@ -39,6 +54,19 @@ def log(skip, max_count, topic_root, delete):
         else:
             click.echo(f"Failed to delete prompt {delete}.")
     else:
+        if insert:
+            prompt_data = PromptData(**json.loads(insert))
+            user, email = get_user_info()
+            prompt = OpenAIPrompt(prompt_data.model, user, email)
+            prompt.model = prompt_data.model
+            prompt.input_messages(prompt_data.messages)
+            prompt.parent = prompt_data.parent
+            prompt.references = prompt_data.references
+            prompt._timestamp = prompt_data.timestamp
+            prompt._request_tokens = prompt_data.request_tokens
+            prompt._response_tokens = prompt_data.response_tokens
+            store.store_prompt(prompt)
+
         recent_prompts = store.select_prompts(skip, skip + max_count, topic_root)
         logs = []
         for record in recent_prompts:
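
Note: a minimal sketch of driving the new --insert option end to end, in the same style as the tests further below. The payload values and the `main` import path are illustrative assumptions, not part of this diff:

    import json
    from click.testing import CliRunner
    from devchat._cli import main  # assumed CLI entry point, as the tests import it

    # The payload mirrors the PromptData model above; "parent" and
    # "references" may be omitted because they have defaults.
    payload = json.dumps({
        "model": "gpt-3.5-turbo",
        "messages": [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi!"},
        ],
        "timestamp": 1672531200,   # required; hashing now rejects prompts without one
        "request_tokens": 10,
        "response_tokens": 5,
    })
    result = CliRunner().invoke(main, ['log', '--insert', payload])
    inserted = json.loads(result.output)[0]  # the stored prompt, including its hash

Because --insert runs with the default skip=0 and max-count=1, the command's output is the freshly stored prompt, which is how the tests below recover each prompt's hash.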
diff --git a/devchat/openai/openai_prompt.py b/devchat/openai/openai_prompt.py
index 45be3073..f9e42dd8 100644
--- a/devchat/openai/openai_prompt.py
+++ b/devchat/openai/openai_prompt.py
@@ -84,11 +84,18 @@ def input_messages(self, messages: List[dict]):
                 logger.warning("Invalid new context message: %s", message)
 
         if not self.request:
-            last_user_message = self._history_messages[Message.CHAT].pop()
-            if last_user_message.role in ("user", "function"):
-                self._new_messages["request"] = last_user_message
-            else:
-                logger.warning("Invalid user request: %s", last_user_message)
+            while True:
+                last_message = self._history_messages[Message.CHAT].pop()
+                if last_message.role in ("user", "function"):
+                    self._new_messages["request"] = last_message
+                    break
+                if last_message.role == "assistant":
+                    self._new_messages["responses"].append(last_message)
+                    continue
+                # Restore the unexpected message and stop, rather than looping forever.
+                self._history_messages[Message.CHAT].append(last_message)
+                logger.warning("Invalid user request: %s", last_message)
+                break
 
     def append_new(self, message_type: str, content: str,
                    available_tokens: int = sys.maxsize) -> bool:
@@ -232,7 +239,7 @@ def _validate_model(self, response_data: dict):
                              f"got '{response_data['model']}'")
 
     def _timestamp_from_dict(self, response_data: dict):
-        if self._timestamp is None:
+        if not self._timestamp:
             self._timestamp = response_data['created']
         elif self._timestamp != response_data['created']:
             raise ValueError(f"Time mismatch: expected {self._timestamp}, "
diff --git a/devchat/prompt.py b/devchat/prompt.py
index 27b66098..8d9f833b 100644
--- a/devchat/prompt.py
+++ b/devchat/prompt.py
@@ -44,13 +44,13 @@ class Prompt(ABC):
     })
     parent: str = None
     references: List[str] = field(default_factory=list)
-    _timestamp: int = None
+    _timestamp: int = 0
    _request_tokens: int = 0
     _response_tokens: int = 0
     _response_reasons: List[str] = field(default_factory=list)
     _hash: str = None
 
-    def _complete_for_hash(self) -> bool:
+    def _complete_for_hashing(self) -> bool:
         """
         Check if the prompt is complete for hashing.
 
@@ -62,6 +62,10 @@ def _complete_for_hash(self) -> bool:
                            self.request, self.responses)
             return False
 
+        if not self.timestamp:
+            logger.warning("Prompt lacks timestamp for hashing: %s", self.request)
+            return False
+
         if not self._response_tokens:
             return False
@@ -114,7 +118,7 @@ def messages(self) -> List[dict]:
     def input_messages(self, messages: List[dict]):
         """
         Input the messages from the chat API to new and history messages.
-        The message list should be generated by the `messages` property.
+        The message list must follow the convention of the `messages` property.
 
         Args:
             messages (List[dict]): The messages from the chat API.
@@ -185,7 +189,7 @@ def finalize_hash(self) -> str:
         Returns:
             str: The hash of the prompt. None if the prompt is incomplete.
         """
-        if not self._complete_for_hash():
+        if not self._complete_for_hashing():
             self._hash = None
 
         if self._hash:
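
The reworked loop changes how input_messages() recovers the request when the list ends with assistant replies. A sketch under the assumption that OpenAIPrompt can be constructed directly, as log.py does above (user name and email are illustrative):

    from devchat.openai.openai_chat import OpenAIPrompt  # same import used in log.py

    prompt = OpenAIPrompt("gpt-3.5-turbo", "Alice", "alice@example.com")
    prompt.input_messages([
        {"role": "user", "content": "This is Topic 1. Reply with the topic number."},
        {"role": "assistant", "content": "Topic 1"},
    ])
    # The trailing assistant message is popped first and folded into the
    # prompt's responses; the loop keeps popping until it reaches the "user"
    # (or "function") message, which becomes the request. Previously a single
    # pop saw the assistant reply and rejected the request with a warning.

One caveat worth flagging in review: if several assistant replies trail the request, append() collects them in pop order (newest first); insert(0, ...) would preserve the original order.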
diff --git a/tests/test_cli_log.py b/tests/test_cli_log.py
index 2ef3b91c..56febe2f 100755
--- a/tests/test_cli_log.py
+++ b/tests/test_cli_log.py
@@ -90,3 +90,84 @@ def test_tokens_with_log(git_repo):  # pylint: disable=W0613
     logs = json.loads(result.output)
     assert _within_range(logs[1]["request_tokens"], logs[0]["request_tokens"])
     assert _within_range(logs[1]["response_tokens"], logs[0]["response_tokens"])
+
+
+def test_log_insert(git_repo):  # pylint: disable=W0613
+    chat1 = """{
+        "model": "gpt-3.5-turbo",
+        "messages": [
+            {
+                "role": "user",
+                "content": "This is Topic 1. Reply with the topic number."
+            },
+            {
+                "role": "assistant",
+                "content": "Topic 1"
+            }
+        ],
+        "timestamp": 1610000000,
+        "request_tokens": 100,
+        "response_tokens": 100
+    }"""
+    result = runner.invoke(
+        main,
+        ['log', '--insert', chat1]
+    )
+    prompt1 = json.loads(result.output)[0]
+
+    chat2 = """{
+        "model": "gpt-3.5-turbo",
+        "messages": [
+            {
+                "role": "user",
+                "content": "This is Topic 2. Reply with the topic number."
+            },
+            {
+                "role": "assistant",
+                "content": "Topic 2"
+            }
+        ],
+        "timestamp": 1620000000,
+        "request_tokens": 200,
+        "response_tokens": 200
+    }"""
+    result = runner.invoke(
+        main,
+        ['log', '--insert', chat2]
+    )
+    prompt2 = json.loads(result.output)[0]
+
+    chat3 = """{
+        "model": "gpt-3.5-turbo",
+        "messages": [
+            {
+                "role": "user",
+                "content": "Let's continue with Topic 1."
+            },
+            {
+                "role": "assistant",
+                "content": "Sure!"
+            }
+        ],
+        "parent": "%s",
+        "timestamp": 1630000000,
+        "request_tokens": 300,
+        "response_tokens": 300
+    }""" % prompt1['hash']
+    result = runner.invoke(
+        main,
+        ['log', '--insert', chat3]
+    )
+    prompt3 = json.loads(result.output)[0]
+    assert prompt3['parent'] == prompt1['hash']
+
+    result = runner.invoke(main, ['log', '-n', '3'])
+    logs = json.loads(result.output)
+    assert logs[0]['hash'] == prompt3['hash']
+    assert logs[1]['hash'] == prompt2['hash']
+    assert logs[2]['hash'] == prompt1['hash']
+
+    result = runner.invoke(main, ['topic', '--list'])
+    topics = json.loads(result.output)
+    assert topics[0]['root_prompt']['hash'] == prompt1['hash']
+    assert topics[1]['root_prompt']['hash'] == prompt2['hash']
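
The timestamp requirement added to _complete_for_hashing() is also why PromptData declares timestamp without a default. A hypothetical pytest sketch of the validation this buys --insert (test name and payload are illustrative, not part of this diff):

    import pytest
    from pydantic import ValidationError
    from devchat._cli.log import PromptData  # the model added in this diff

    def test_prompt_data_requires_timestamp():
        # Omitting "timestamp" should fail validation before anything is stored.
        payload = {
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "hi"}],
            "request_tokens": 1,
            "response_tokens": 1,
        }
        with pytest.raises(ValidationError):
            PromptData(**payload)

Rejecting such payloads at the model boundary keeps un-hashable prompts out of the store entirely, instead of surfacing later as a finalize_hash() failure.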