Skip to content

Commit

Permalink
Rename 'response' to 'responses' in Prompt
Browse files Browse the repository at this point in the history
- Updated 'response' to 'responses' in the OpenAIPrompt and Prompt classes.
- Updated all instances of 'response' to 'responses' in the codebase.
- Adjusted the shortlog method in Prompt to accommodate the change.
- Updated tests to reflect the changes.
  • Loading branch information
basicthinker committed Jul 20, 2023
1 parent de7e6ef commit 9177cf0
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 54 deletions.
36 changes: 18 additions & 18 deletions devchat/openai/openai_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _prepend_history(self, message_type: str, message: Message,

def prepend_history(self, prompt: 'OpenAIPrompt', token_limit: int = math.inf) -> bool:
# Prepend the first response and the request of the prompt
if not self._prepend_history(Message.CHAT, prompt.response[0], token_limit):
if not self._prepend_history(Message.CHAT, prompt.responses[0], token_limit):
return False
if not self._prepend_history(Message.CHAT, prompt.request, token_limit):
return False
Expand Down Expand Up @@ -163,10 +163,10 @@ def set_response(self, response_str: str):

for choice in response_data['choices']:
index = choice['index']
if index >= len(self.response):
self.response.extend([None] * (index - len(self.response) + 1))
self.response[index] = OpenAIMessage(**choice['message'],
finish_reason=choice['finish_reason'])
if index >= len(self.responses):
self.responses.extend([None] * (index - len(self.responses) + 1))
self.responses[index] = OpenAIMessage(**choice['message'],
finish_reason=choice['finish_reason'])

def append_response(self, delta_str: str) -> str:
"""
Expand All @@ -189,33 +189,33 @@ def append_response(self, delta_str: str) -> str:
index = choice['index']
finish_reason = choice['finish_reason']

if index >= len(self.response):
self.response.extend([None] * (index - len(self.response) + 1))
if index >= len(self.responses):
self.responses.extend([None] * (index - len(self.responses) + 1))

if not self.response[index]:
self.response[index] = OpenAIMessage(**delta)
if not self.responses[index]:
self.responses[index] = OpenAIMessage(**delta)
if index == 0:
delta_content = self.formatted_header()
delta_content += self.response[0].content if self.response[0].content else ''
delta_content += self.responses[0].content if self.responses[0].content else ''
else:
if index == 0:
delta_content = self.response[0].stream_from_dict(delta)
delta_content = self.responses[0].stream_from_dict(delta)
else:
self.response[index].stream_from_dict(delta)
self.responses[index].stream_from_dict(delta)

if 'function_call' in delta:
if 'name' in delta['function_call']:
self.response[index].function_call['name'] = \
self.response[index].function_call.get('name', '') + \
self.responses[index].function_call['name'] = \
self.responses[index].function_call.get('name', '') + \
delta['function_call']['name']
if 'arguments' in delta['function_call']:
self.response[index].function_call['arguments'] = \
self.response[index].function_call.get('arguments', '') + \
self.responses[index].function_call['arguments'] = \
self.responses[index].function_call.get('arguments', '') + \
delta['function_call']['arguments']

if finish_reason:
if finish_reason == 'function_call':
delta_content += self.response[index].function_call_to_json()
delta_content += self.responses[index].function_call_to_json()
delta_content += f"\n\nfinish_reason: {finish_reason}"

return delta_content
Expand All @@ -225,7 +225,7 @@ def _count_response_tokens(self) -> int:
return self._response_tokens

total = 0
for response_message in self.response:
for response_message in self.responses:
total += response_tokens(response_message.content, self.model)
self._response_tokens = total
return total
Expand Down
59 changes: 31 additions & 28 deletions devchat/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class Prompt(ABC):
Message.INSTRUCT: [],
'request': None,
Message.CONTEXT: [],
'response': []
'responses': []
})
_history_messages: Dict[str, Message] = field(default_factory=lambda: {
Message.CONTEXT: [],
Expand All @@ -56,9 +56,9 @@ def _check_complete(self) -> bool:
Returns:
bool: Whether the prompt is complete.
"""
if not self.request or not self.response:
if not self.request or not self.responses:
logger.warning("Incomplete prompt: request = %s, response = %s",
self.request, self.response)
self.request, self.responses)
return False

if not self._request_tokens or not self._response_tokens:
Expand All @@ -81,8 +81,8 @@ def request(self) -> Message:
return self._new_messages['request']

@property
def response(self) -> List[Message]:
return self._new_messages['response']
def responses(self) -> List[Message]:
return self._new_messages['responses']

@property
def request_tokens(self) -> int:
Expand Down Expand Up @@ -193,7 +193,7 @@ def finalize_hash(self) -> str:
self._count_response_tokens()

data = asdict(self)
assert data.pop('_hash') is None
data.pop('_hash')
string = str(tuple(sorted(data.items())))
self._hash = hashlib.sha256(string.encode('utf-8')).hexdigest()
return self._hash
Expand All @@ -219,38 +219,41 @@ def formatted_response(self, index: int) -> str:
"""
formatted_str = self.formatted_header()

if index >= len(self.response) or not self.response[index]:
if index >= len(self.responses) or not self.responses[index]:
logger.error("Response index %d is incomplete to format: request = %s, response = %s",
index, self.request, self.response)
index, self.request, self.responses)
return None

if self.response[index].content:
formatted_str += self.response[index].content
if self.responses[index].content:
formatted_str += self.responses[index].content
formatted_str += "\n\n"

if self.response[index].finish_reason == 'function_call':
formatted_str += self.response[index].function_call_to_json()
formatted_str += f"\n\nfinish_reason: {self.response[index].finish_reason}" + "\n\n"
if self.responses[index].finish_reason == 'function_call':
formatted_str += self.responses[index].function_call_to_json()
formatted_str += f"\n\nfinish_reason: {self.responses[index].finish_reason}" + "\n\n"

formatted_str += f"prompt {self.hash}"

return formatted_str

def shortlog(self) -> List[dict]:
"""Generate a shortlog of the prompt."""
if not self.request or not self.response:
if not self.request or not self.responses:
raise ValueError("Prompt is incomplete for shortlog.")
logs = []
for message in self.response:
shortlog_data = {
"user": user_id(self.user_name, self.user_email)[0],
"date": self._timestamp,
"context": [msg.to_dict() for msg in self.new_context],
"request": self.request.content,
"response": ((message.content if message.content else "")
+ message.function_call_to_json()),
"hash": self.hash,
"parent": self.parent
}
logs.append(shortlog_data)
return logs

responses = []
for message in self.responses:
responses += ((message.content if message.content else "")
+ message.function_call_to_json())

return {
"user": user_id(self.user_name, self.user_email)[0],
"date": self._timestamp,
"context": [msg.to_dict() for msg in self.new_context],
"request": self.request.content,
"responses": responses,
"request_tokens": self._request_tokens,
"response_tokens": self._response_tokens,
"hash": self.hash,
"parent": self.parent
}
6 changes: 3 additions & 3 deletions tests/test_cli_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ def test_prompt_log_with_functions(git_repo, functions_file): # pylint: disable

result_json = json.loads(result.output)
assert result.exit_code == 0
assert result_json[0][0]['request'] == 'What is the weather like in Boston?'
assert result_json[0][0]['response'].find("```command") >= 0
assert result_json[0][0]['response'].find("get_current_weather") >= 0
assert result_json[0]['request'] == 'What is the weather like in Boston?'
assert result_json[0]['responses'][0].find("```command") >= 0
assert result_json[0]['responses'][0].find("get_current_weather") >= 0


def test_prompt_log_compatibility():
Expand Down
10 changes: 5 additions & 5 deletions tests/test_openai_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def test_prompt_init_and_set_response():
assert prompt.timestamp == 1677649420
assert prompt.request_tokens == 56
assert prompt.response_tokens == 31
assert len(prompt.response) == 1
assert prompt.response[0].role == "assistant"
assert prompt.response[0].content == "The 2020 World Series was played in Arlington, Texas."
assert len(prompt.responses) == 1
assert prompt.responses[0].role == "assistant"
assert prompt.responses[0].content == "The 2020 World Series was played in Arlington, Texas."


def test_prompt_model_mismatch():
Expand Down Expand Up @@ -103,8 +103,8 @@ def test_append_response(responses):
OpenAIMessage(role='assistant', content='Tomorrow!')
]

assert len(prompt.response) == len(expected_messages)
for index, message in enumerate(prompt.response):
assert len(prompt.responses) == len(expected_messages)
for index, message in enumerate(prompt.responses):
assert message.role == expected_messages[index].role
assert message.content == expected_messages[index].content

Expand Down

0 comments on commit 9177cf0

Please sign in to comment.