Skip to content

Commit

Permalink
Improve response format in Assistant and Prompt
Browse files Browse the repository at this point in the history
- Add _response_reasons field in Prompt to store finish reasons.
- Update append_response in OpenAIPrompt to store finish reason.
- Replace formatted_response with formatted_footer and full_response.
- Add timestamp and hash checks in Prompt's formatted_header and footer.
  • Loading branch information
basicthinker committed Jul 21, 2023
1 parent 8c58163 commit 348a486
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 39 deletions.
18 changes: 11 additions & 7 deletions devchat/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,20 @@ def iterate_response(self) -> Iterator[str]:
Iterator[str]: An iterator over response strings from the chat API.
"""
if self._chat.config.stream:
response_iterator = self._chat.stream_response(self._prompt)
for chunk in response_iterator:
yield self._prompt.append_response(str(chunk))
first_chunk = True
for chunk in self._chat.stream_response(self._prompt):
delta = self._prompt.append_response(str(chunk))
if first_chunk:
first_chunk = False
yield self._prompt.formatted_header()
yield delta
self._store.store_prompt(self._prompt)
yield f'\n\nprompt {self._prompt.hash}\n'
yield self._prompt.formatted_footer(0) + '\n'
for index in range(1, len(self._prompt.responses)):
yield self._prompt.formatted_response(index) + '\n'
yield self._prompt.formatted_full_response(index) + '\n'
else:
response_str = str(self._chat.complete_response(self._prompt))
response_str = self._chat.complete_response(self._prompt)
self._prompt.set_response(response_str)
self._store.store_prompt(self._prompt)
for index in range(len(self._prompt.responses)):
yield self._prompt.formatted_response(index) + '\n'
yield self._prompt.formatted_full_response(index) + '\n'
2 changes: 0 additions & 2 deletions devchat/openai/openai_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ class OpenAIMessage(Message):
role: str = None
name: Optional[str] = None
function_call: Dict[str, str] = field(default_factory=dict)
finish_reason: str = None

def __post_init__(self):
if not self._validate_role():
Expand All @@ -23,7 +22,6 @@ def __post_init__(self):

def to_dict(self) -> dict:
state = asdict(self)
del state['finish_reason']
if state['name'] is None:
del state['name']
if not state['function_call'] or len(state['function_call'].keys()) == 0:
Expand Down
15 changes: 7 additions & 8 deletions devchat/openai/openai_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,10 @@ def set_response(self, response_str: str):
index = choice['index']
if index >= len(self.responses):
self.responses.extend([None] * (index - len(self.responses) + 1))
self.responses[index] = OpenAIMessage(**choice['message'],
finish_reason=choice['finish_reason'])
self._response_reasons.extend([None] * (index - len(self._response_reasons) + 1))
self.responses[index] = OpenAIMessage(**choice['message'])
if choice['finish_reason']:
self._response_reasons[index] = choice['finish_reason']

def append_response(self, delta_str: str) -> str:
"""
Expand All @@ -191,12 +193,12 @@ def append_response(self, delta_str: str) -> str:

if index >= len(self.responses):
self.responses.extend([None] * (index - len(self.responses) + 1))
self._response_reasons.extend([None] * (index - len(self._response_reasons) + 1))

if not self.responses[index]:
self.responses[index] = OpenAIMessage(**delta)
if index == 0:
delta_content = self.formatted_header()
delta_content += self.responses[0].content if self.responses[0].content else ''
delta_content = self.responses[0].content if self.responses[0].content else ''
else:
if index == 0:
delta_content = self.responses[0].stream_from_dict(delta)
Expand All @@ -214,10 +216,7 @@ def append_response(self, delta_str: str) -> str:
delta['function_call']['arguments']

if finish_reason:
if finish_reason == 'function_call':
delta_content += self.responses[index].function_call_to_json()
delta_content += f"\n\nfinish_reason: {finish_reason}"

self._response_reasons[index] = finish_reason
return delta_content

def _count_response_tokens(self) -> int:
Expand Down
56 changes: 36 additions & 20 deletions devchat/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class Prompt(ABC):
_timestamp: int = None
_request_tokens: int = 0
_response_tokens: int = 0
_response_reasons: List[str] = field(default_factory=list)
_hash: str = None

def _check_complete(self) -> bool:
Expand All @@ -56,14 +57,12 @@ def _check_complete(self) -> bool:
Returns:
bool: Whether the prompt is complete.
"""
if not self.request or not self.responses:
logger.warning("Incomplete prompt: request = %s, response = %s",
self.request, self.responses)
if not self.request or not self._request_tokens or not self.responses:
logger.warning("Incomplete prompt: request = %s (%d), response = %s",
self.request, self._request_tokens, self.responses)
return False

if not self._request_tokens or not self._response_tokens:
logger.warning("Incomplete prompt: request_tokens = %d, response_tokens = %d",
self._request_tokens, self._response_tokens)
if not self._response_tokens:
return False

return True
Expand Down Expand Up @@ -202,39 +201,56 @@ def formatted_header(self) -> str:
"""Formatted string header of the prompt."""
formatted_str = f"User: {user_id(self.user_name, self.user_email)[0]}\n"

if not self._timestamp:
raise ValueError(f"Prompt lacks timestamp for formatting header: {self.request}")

local_time = unix_to_local_datetime(self._timestamp)
formatted_str += f"Date: {local_time.strftime('%a %b %d %H:%M:%S %Y %z')}\n\n"

return formatted_str

def formatted_response(self, index: int) -> str:
def formatted_footer(self, index: int) -> str:
    """Format the footer appended after the response at *index*.

    The footer optionally carries a human-readable note derived from the
    stored finish reason (and, for function calls, the serialized call),
    and always ends with the prompt-hash line.

    Args:
        index (int): Index of the response whose footer is formatted.

    Returns:
        str: The formatted footer string.

    Raises:
        ValueError: If the prompt has no hash yet (it must be stored/hashed
            before a footer can reference it).
    """
    if not self.hash:
        raise ValueError(f"Prompt lacks hash for formatting footer: {self.request}")

    footer = "\n\n"
    reason = self._response_reasons[index]

    # Map each known finish reason to its explanatory note.
    reason_notes = {
        'length': "Incomplete model output due to max_tokens parameter or token limit",
        'function_call': "The model decided to call a function",
        'content_filter': "Omitted content due to a flag from our content filters",
    }
    if reason == 'function_call':
        # Include the serialized function call before the note.
        footer += self.responses[index].function_call_to_json() + "\n\n"
    note = reason_notes.get(reason)

    if note:
        footer += f"Note: {note} (finish_reason: {reason})\n\n"

    return footer + f"prompt {self.hash}"

def formatted_full_response(self, index: int) -> str:
"""
Formatted response of the prompt.
Formatted full response of the prompt.
Args:
index (int): The index of the response to format.
Returns:
str: The formatted response string. None if the response is incomplete.
str: The formatted response string. None if the response is invalid.
"""
formatted_str = self.formatted_header()

if index >= len(self.responses) or not self.responses[index]:
logger.error("Response index %d is incomplete to format: request = %s, response = %s",
logger.error("Response index %d is invalid to format: request = %s, response = %s",
index, self.request, self.responses)
return None

formatted_str = self.formatted_header()

if self.responses[index].content:
formatted_str += self.responses[index].content
formatted_str += "\n\n"

if self.responses[index].finish_reason == 'function_call':
formatted_str += self.responses[index].function_call_to_json()
formatted_str += f"\n\nfinish_reason: {self.responses[index].finish_reason}" + "\n\n"

formatted_str += f"prompt {self.hash}"

return formatted_str
return formatted_str + self.formatted_footer(index)

def shortlog(self) -> List[dict]:
"""Generate a shortlog of the prompt."""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cli_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ def test_prompt_with_functions(git_repo, functions_file): # pylint: disable=W06
# call with -f option
result = runner.invoke(main, ['prompt', '-m', 'gpt-3.5-turbo', '-f', functions_file,
"What is the weather like in Boston?"])

content = get_content(result.output)
print(result.output)
assert result.exit_code == 0
content = get_content(result.output)
assert 'finish_reason: function_call' in content
assert '```command' in content
assert '"name": "get_current_weather"' in content
Expand Down

0 comments on commit 348a486

Please sign in to comment.