Skip to content

Commit

Permalink
Fix _update_topics_table() and add new tests
Browse files Browse the repository at this point in the history
- Refactor _update_topics_table method in store.py to fix KeyError issue
- Add test_select_recent_with_topic and test_select_recent_with_nested_topic in test_store.py
  • Loading branch information
basicthinker committed Jun 5, 2023
1 parent ae3bb1d commit 0dec1fb
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 14 deletions.
21 changes: 7 additions & 14 deletions devchat/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,11 @@ def _initialize_topics_table(self):
})

def _update_topics_table(self, prompt: Prompt):
if self._graph.in_degree(prompt):
if self._graph.in_degree(prompt.hash):
logger.error("Prompt %s not a leaf to update topics table", prompt.hash)

if prompt.parent:
print(self._topics_table.all())
topic = next((topic for topic in self._topics_table.all()
if prompt.parent in topic['leaves']), None)
if topic:
Expand Down Expand Up @@ -116,22 +117,14 @@ def store_prompt(self, prompt: Prompt):
prompt.parent, prompt.hash)
else:
self._graph.add_edge(prompt.hash, prompt.parent)
self._update_topics_table(prompt)
else:
self._topics_table.insert({
'ancestor_hash': prompt.hash,
'descendant_hash': prompt.hash,
'latest_time': prompt.timestamp,
'title': None,
'hidden': False
})

self._update_topics_table(prompt)

for reference_hash in prompt.references:
if reference_hash not in self._graph:
logger.warning("Reference %s not found while Prompt %s is stored to graph store.",
reference_hash, prompt.hash)
else:
self._graph.add_edge(prompt.hash, reference_hash)
logger.error("Reference %s not found while Prompt %s is stored to graph store.",
reference_hash, prompt.hash)

nx.write_graphml(self._graph, self._graph_path)

def get_prompt(self, prompt_hash: str) -> Prompt:
Expand Down
143 changes: 143 additions & 0 deletions tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,146 @@ def test_select_recent(tmp_path):
assert len(recent_prompts) == 3
for index, prompt in enumerate(recent_prompts):
assert prompt.hash == hashes[4 - index]


def test_select_recent_with_topic(tmp_path):
    """select_recent() restricted to a topic yields that topic's children newest-first."""
    chat = OpenAIChat(OpenAIChatConfig(model="gpt-3.5-turbo"))
    store = Store(tmp_path / "store.graphml", chat)

    # A single root prompt anchors the topic.
    root = chat.init_prompt("Root question")
    root_response = '''{
    "id": "chatcmpl-root",
    "object": "chat.completion",
    "created": 1677649400,
    "model": "gpt-3.5-turbo-0301",
    "usage": {"prompt_tokens": 56, "completion_tokens": 31, "total_tokens": 87},
    "choices": [
        {
            "message": {
                "role": "assistant",
                "content": "Root answer"
            },
            "finish_reason": "stop",
            "index": 0
        }
    ]
    }'''
    root.set_response(root_response)
    store.store_prompt(root)

    # Attach three children to the root, each with a later "created" timestamp.
    leaf_hashes = []
    for index in range(3):
        prompt = chat.init_prompt(f"Child question {index}")
        prompt.parent = root.hash
        response = f'''{{
        "id": "chatcmpl-child{index}",
        "object": "chat.completion",
        "created": 167764940{index + 1},
        "model": "gpt-3.5-turbo-0301",
        "usage": {{"prompt_tokens": 56, "completion_tokens": 31, "total_tokens": 87}},
        "choices": [
            {{
                "message": {{
                    "role": "assistant",
                    "content": "Child answer {index}"
                }},
                "finish_reason": "stop",
                "index": 0
            }}
        ]
        }}'''
        prompt.set_response(response)
        store.store_prompt(prompt)
        leaf_hashes.append(prompt.hash)

    # The two most recent prompts within the topic are the last two children,
    # ordered newest-first.
    recent_prompts = store.select_recent(0, 2, topic=root.hash)
    assert [prompt.hash for prompt in recent_prompts] == [leaf_hashes[2], leaf_hashes[1]]


def test_select_recent_with_nested_topic(tmp_path):
    """select_recent() traverses a topic containing a child and its grandchildren."""
    chat = OpenAIChat(OpenAIChatConfig(model="gpt-3.5-turbo"))
    store = Store(tmp_path / "store.graphml", chat)

    # Root prompt that anchors the topic.
    root = chat.init_prompt("Root question")
    root_response = '''{
    "id": "chatcmpl-root",
    "object": "chat.completion",
    "created": 1677649400,
    "model": "gpt-3.5-turbo-0301",
    "usage": {"prompt_tokens": 56, "completion_tokens": 31, "total_tokens": 87},
    "choices": [
        {
            "message": {
                "role": "assistant",
                "content": "Root answer"
            },
            "finish_reason": "stop",
            "index": 0
        }
    ]
    }'''
    root.set_response(root_response)
    store.store_prompt(root)

    # One child hangs off the root.
    child = chat.init_prompt("Child question")
    child.parent = root.hash
    child_response = '''{
    "id": "chatcmpl-child",
    "object": "chat.completion",
    "created": 1677649401,
    "model": "gpt-3.5-turbo-0301",
    "usage": {"prompt_tokens": 56, "completion_tokens": 31, "total_tokens": 87},
    "choices": [
        {
            "message": {
                "role": "assistant",
                "content": "Child answer"
            },
            "finish_reason": "stop",
            "index": 0
        }
    ]
    }'''
    child.set_response(child_response)
    store.store_prompt(child)

    # Two grandchildren hang off the child, with increasing timestamps.
    grandchild_hashes = []
    for index in range(2):
        prompt = chat.init_prompt(f"Grandchild question {index}")
        prompt.parent = child.hash
        response = f'''{{
        "id": "chatcmpl-grandchild{index}",
        "object": "chat.completion",
        "created": 167764940{index + 2},
        "model": "gpt-3.5-turbo-0301",
        "usage": {{"prompt_tokens": 56, "completion_tokens": 31, "total_tokens": 87}},
        "choices": [
            {{
                "message": {{
                    "role": "assistant",
                    "content": "Grandchild answer {index}"
                }},
                "finish_reason": "stop",
                "index": 0
            }}
        ]
        }}'''
        prompt.set_response(response)
        store.store_prompt(prompt)
        grandchild_hashes.append(prompt.hash)

    # Skipping the newest prompt, the next two in the topic are the first
    # grandchild followed by the child.
    recent_prompts = store.select_recent(1, 3, topic=root.hash)
    assert [prompt.hash for prompt in recent_prompts] == [grandchild_hashes[0], child.hash]

0 comments on commit 0dec1fb

Please sign in to comment.