diff --git a/devchat/openai/openai_prompt.py b/devchat/openai/openai_prompt.py index b5fedf61..e1ed1b97 100644 --- a/devchat/openai/openai_prompt.py +++ b/devchat/openai/openai_prompt.py @@ -90,10 +90,11 @@ def set_response(self, response_str: str): self._request_tokens = response_data['usage']['prompt_tokens'] self._response_tokens = response_data['usage']['completion_tokens'] - self._new_messages['response'] = { - choice['index']: OpenAIMessage(**choice['message']) - for choice in response_data['choices'] - } + for choice in response_data['choices']: + index = choice['index'] + if index >= len(self.response): + self.response.extend([None] * (index - len(self.response) + 1)) + self.response[index] = OpenAIMessage(**choice['message']) self.set_hash() def append_response(self, delta_str: str) -> str: @@ -116,7 +117,10 @@ def append_response(self, delta_str: str) -> str: delta = choice['delta'] index = choice['index'] - if index not in self.response: + if index >= len(self.response): + self.response.extend([None] * (index - len(self.response) + 1)) + + if not self.response[index]: self.response[index] = OpenAIMessage(**delta) if index == 0: delta_content = self.formatted_header() diff --git a/devchat/prompt.py b/devchat/prompt.py index ee224d8b..72c733d9 100644 --- a/devchat/prompt.py +++ b/devchat/prompt.py @@ -36,7 +36,7 @@ class Prompt(ABC): Message.INSTRUCT: [], 'request': None, Message.CONTEXT: [], - 'response': {} + 'response': [] }) _history_messages: Dict[str, Message] = field(default_factory=lambda: { Message.CONTEXT: [], @@ -73,7 +73,7 @@ def request(self) -> Message: return self._new_messages['request'] @property - def response(self) -> Dict[int, Message]: + def response(self) -> List[Message]: return self._new_messages['response'] @property