Skip to content

Commit

Permalink
Replace dict with list for storing responses in OpenAIPrompt
Browse files Browse the repository at this point in the history
- Changed response storage from a dictionary indexed by int to a list.
- Updated related code to handle list indexing and extending.
  • Loading branch information
basicthinker committed May 22, 2023
1 parent 134319b commit 9fa4e31
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
14 changes: 9 additions & 5 deletions devchat/openai/openai_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,11 @@ def set_response(self, response_str: str):
self._request_tokens = response_data['usage']['prompt_tokens']
self._response_tokens = response_data['usage']['completion_tokens']

self._new_messages['response'] = {
choice['index']: OpenAIMessage(**choice['message'])
for choice in response_data['choices']
}
for choice in response_data['choices']:
index = choice['index']
if index >= len(self.response):
self.response.extend([None] * (index - len(self.response) + 1))
self.response[index] = OpenAIMessage(**choice['message'])
self.set_hash()

def append_response(self, delta_str: str) -> str:
Expand All @@ -116,7 +117,10 @@ def append_response(self, delta_str: str) -> str:
delta = choice['delta']
index = choice['index']

if index not in self.response:
if index >= len(self.response):
self.response.extend([None] * (index - len(self.response) + 1))

if not self.response[index]:
self.response[index] = OpenAIMessage(**delta)
if index == 0:
delta_content = self.formatted_header()
Expand Down
4 changes: 2 additions & 2 deletions devchat/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class Prompt(ABC):
Message.INSTRUCT: [],
'request': None,
Message.CONTEXT: [],
'response': {}
'response': []
})
_history_messages: Dict[str, Message] = field(default_factory=lambda: {
Message.CONTEXT: [],
Expand Down Expand Up @@ -73,7 +73,7 @@ def request(self) -> Message:
return self._new_messages['request']

@property
def response(self) -> Dict[int, Message]:
def response(self) -> List[Message]:
return self._new_messages['response']

@property
Expand Down

0 comments on commit 9fa4e31

Please sign in to comment.