diff --git a/devchat/openai/openai_prompt.py b/devchat/openai/openai_prompt.py index 28b18e73..a30adb7c 100644 --- a/devchat/openai/openai_prompt.py +++ b/devchat/openai/openai_prompt.py @@ -127,7 +127,7 @@ def _prepend_history(self, message_type: str, message: Message, def prepend_history(self, prompt: 'OpenAIPrompt', token_limit: int = math.inf) -> bool: # Prepend the first response and the request of the prompt - if not self._prepend_history(Message.CHAT, prompt.response[0], token_limit): + if not self._prepend_history(Message.CHAT, prompt.responses[0], token_limit): return False if not self._prepend_history(Message.CHAT, prompt.request, token_limit): return False @@ -163,10 +163,10 @@ def set_response(self, response_str: str): for choice in response_data['choices']: index = choice['index'] - if index >= len(self.response): - self.response.extend([None] * (index - len(self.response) + 1)) - self.response[index] = OpenAIMessage(**choice['message'], - finish_reason=choice['finish_reason']) + if index >= len(self.responses): + self.responses.extend([None] * (index - len(self.responses) + 1)) + self.responses[index] = OpenAIMessage(**choice['message'], + finish_reason=choice['finish_reason']) def append_response(self, delta_str: str) -> str: """ @@ -189,33 +189,33 @@ def append_response(self, delta_str: str) -> str: index = choice['index'] finish_reason = choice['finish_reason'] - if index >= len(self.response): - self.response.extend([None] * (index - len(self.response) + 1)) + if index >= len(self.responses): + self.responses.extend([None] * (index - len(self.responses) + 1)) - if not self.response[index]: - self.response[index] = OpenAIMessage(**delta) + if not self.responses[index]: + self.responses[index] = OpenAIMessage(**delta) if index == 0: delta_content = self.formatted_header() - delta_content += self.response[0].content if self.response[0].content else '' + delta_content += self.responses[0].content if self.responses[0].content else '' else: if index == 
0: - delta_content = self.response[0].stream_from_dict(delta) + delta_content = self.responses[0].stream_from_dict(delta) else: - self.response[index].stream_from_dict(delta) + self.responses[index].stream_from_dict(delta) if 'function_call' in delta: if 'name' in delta['function_call']: - self.response[index].function_call['name'] = \ - self.response[index].function_call.get('name', '') + \ + self.responses[index].function_call['name'] = \ + self.responses[index].function_call.get('name', '') + \ delta['function_call']['name'] if 'arguments' in delta['function_call']: - self.response[index].function_call['arguments'] = \ - self.response[index].function_call.get('arguments', '') + \ + self.responses[index].function_call['arguments'] = \ + self.responses[index].function_call.get('arguments', '') + \ delta['function_call']['arguments'] if finish_reason: if finish_reason == 'function_call': - delta_content += self.response[index].function_call_to_json() + delta_content += self.responses[index].function_call_to_json() delta_content += f"\n\nfinish_reason: {finish_reason}" return delta_content @@ -225,7 +225,7 @@ def _count_response_tokens(self) -> int: return self._response_tokens total = 0 - for response_message in self.response: + for response_message in self.responses: total += response_tokens(response_message.content, self.model) self._response_tokens = total return total diff --git a/devchat/prompt.py b/devchat/prompt.py index 0a6e7a4f..26838858 100644 --- a/devchat/prompt.py +++ b/devchat/prompt.py @@ -36,7 +36,7 @@ class Prompt(ABC): Message.INSTRUCT: [], 'request': None, Message.CONTEXT: [], - 'response': [] + 'responses': [] }) _history_messages: Dict[str, Message] = field(default_factory=lambda: { Message.CONTEXT: [], @@ -56,9 +56,9 @@ def _check_complete(self) -> bool: Returns: bool: Whether the prompt is complete. 
""" - if not self.request or not self.response: + if not self.request or not self.responses: logger.warning("Incomplete prompt: request = %s, response = %s", - self.request, self.response) + self.request, self.responses) return False if not self._request_tokens or not self._response_tokens: @@ -81,8 +81,8 @@ def request(self) -> Message: return self._new_messages['request'] @property - def response(self) -> List[Message]: - return self._new_messages['response'] + def responses(self) -> List[Message]: + return self._new_messages['responses'] @property def request_tokens(self) -> int: @@ -193,7 +193,7 @@ def finalize_hash(self) -> str: self._count_response_tokens() data = asdict(self) - assert data.pop('_hash') is None + data.pop('_hash') string = str(tuple(sorted(data.items()))) self._hash = hashlib.sha256(string.encode('utf-8')).hexdigest() return self._hash @@ -219,18 +219,18 @@ def formatted_response(self, index: int) -> str: """ formatted_str = self.formatted_header() - if index >= len(self.response) or not self.response[index]: + if index >= len(self.responses) or not self.responses[index]: logger.error("Response index %d is incomplete to format: request = %s, response = %s", - index, self.request, self.response) + index, self.request, self.responses) return None - if self.response[index].content: - formatted_str += self.response[index].content + if self.responses[index].content: + formatted_str += self.responses[index].content formatted_str += "\n\n" - if self.response[index].finish_reason == 'function_call': - formatted_str += self.response[index].function_call_to_json() - formatted_str += f"\n\nfinish_reason: {self.response[index].finish_reason}" + "\n\n" + if self.responses[index].finish_reason == 'function_call': + formatted_str += self.responses[index].function_call_to_json() + formatted_str += f"\n\nfinish_reason: {self.responses[index].finish_reason}" + "\n\n" formatted_str += f"prompt {self.hash}" @@ -238,19 +238,22 @@ def formatted_response(self, 
index: int) -> str: def shortlog(self) -> List[dict]: """Generate a shortlog of the prompt.""" - if not self.request or not self.response: + if not self.request or not self.responses: raise ValueError("Prompt is incomplete for shortlog.") - logs = [] - for message in self.response: - shortlog_data = { - "user": user_id(self.user_name, self.user_email)[0], - "date": self._timestamp, - "context": [msg.to_dict() for msg in self.new_context], - "request": self.request.content, - "response": ((message.content if message.content else "") - + message.function_call_to_json()), - "hash": self.hash, - "parent": self.parent - } - logs.append(shortlog_data) - return logs + + responses = [] + for message in self.responses: + responses.append((message.content if message.content else "") + + message.function_call_to_json()) + + return { + "user": user_id(self.user_name, self.user_email)[0], + "date": self._timestamp, + "context": [msg.to_dict() for msg in self.new_context], + "request": self.request.content, + "responses": responses, + "request_tokens": self._request_tokens, + "response_tokens": self._response_tokens, + "hash": self.hash, + "parent": self.parent + } diff --git a/tests/test_cli_prompt.py b/tests/test_cli_prompt.py index d5016ff7..e64398eb 100644 --- a/tests/test_cli_prompt.py +++ b/tests/test_cli_prompt.py @@ -137,9 +137,9 @@ def test_prompt_log_with_functions(git_repo, functions_file): # pylint: disable result_json = json.loads(result.output) assert result.exit_code == 0 - assert result_json[0][0]['request'] == 'What is the weather like in Boston?' - assert result_json[0][0]['response'].find("```command") >= 0 - assert result_json[0][0]['response'].find("get_current_weather") >= 0 + assert result_json[0]['request'] == 'What is the weather like in Boston?' 
+ assert result_json[0]['responses'][0].find("```command") >= 0 + assert result_json[0]['responses'][0].find("get_current_weather") >= 0 def test_prompt_log_compatibility(): diff --git a/tests/test_openai_prompt.py b/tests/test_openai_prompt.py index a8ac6bb4..b8327769 100644 --- a/tests/test_openai_prompt.py +++ b/tests/test_openai_prompt.py @@ -38,9 +38,9 @@ def test_prompt_init_and_set_response(): assert prompt.timestamp == 1677649420 assert prompt.request_tokens == 56 assert prompt.response_tokens == 31 - assert len(prompt.response) == 1 - assert prompt.response[0].role == "assistant" - assert prompt.response[0].content == "The 2020 World Series was played in Arlington, Texas." + assert len(prompt.responses) == 1 + assert prompt.responses[0].role == "assistant" + assert prompt.responses[0].content == "The 2020 World Series was played in Arlington, Texas." def test_prompt_model_mismatch(): @@ -103,8 +103,8 @@ def test_append_response(responses): OpenAIMessage(role='assistant', content='Tomorrow!') ] - assert len(prompt.response) == len(expected_messages) - for index, message in enumerate(prompt.response): + assert len(prompt.responses) == len(expected_messages) + for index, message in enumerate(prompt.responses): assert message.role == expected_messages[index].role assert message.content == expected_messages[index].content