Skip to content

Commit

Permalink
Simplify topic management in Store class
Browse files Browse the repository at this point in the history
- Remove _roots and _leaves methods
- Refactor _initialize_topics_table method to use roots directly
- Update _update_topics_table method to handle topic updates more efficiently
- Modify select_topics method to return visible topics only
- Add test cases for topic selection and prompt selection within a topic
  • Loading branch information
basicthinker committed Jun 6, 2023
1 parent 6706dd1 commit d96528f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 36 deletions.
48 changes: 14 additions & 34 deletions devchat/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,28 +42,13 @@ def __init__(self, store_dir: str, chat: Chat):
if not self._topics_table or self._topics_table.all() == []:
self._initialize_topics_table()

def _roots(self) -> List[object]:
graph = self._graph
return [node for node in graph.nodes() if graph.out_degree(node) == 0]

def _leaves(self) -> List[object]:
graph = self._graph
return [node for node in graph.nodes() if graph.in_degree(node) == 0]

def _initialize_topics_table(self):
roots = self._roots()
all_leaves = self._leaves()
root_to_leaves = {root: [] for root in roots}

for leaf in all_leaves:
root = next(nx.descendants(self._graph, leaf).intersection(roots))
root_to_leaves[root].append(leaf)

for root, leaves in root_to_leaves.items():
latest_time = max(self._graph.nodes[leaf]['timestamp'] for leaf in leaves)
roots = [node for node in self._graph.nodes() if self._graph.out_degree(node) == 0]
for root in roots:
latest_time = max(self._graph.nodes[node]['timestamp'] for
node in nx.ancestors(self._graph, root))
self._topics_table.insert({
'root': root,
'leaves': leaves,
'latest_time': latest_time,
'title': None,
'hidden': False
Expand All @@ -74,21 +59,15 @@ def _update_topics_table(self, prompt: Prompt):
logger.error("Prompt %s not a leaf to update topics table", prompt.hash)

if prompt.parent:
print(self._topics_table.all())
topic = next((topic for topic in self._topics_table.all()
if prompt.parent in topic['leaves']), None)
if topic:
topic['leaves'].remove(prompt.parent)
topic['leaves'].append(prompt.hash)
topic['latest_time'] = max(topic['latest_time'], prompt.timestamp)
self._topics_table.update(topic, doc_ids=[topic.doc_id])
else:
logger.error("Parent %s of prompt %s not found in topic leaves",
prompt.parent, prompt.hash)
for topic in self._topics_table.all():
if prompt.parent == topic['root'] or \
prompt.parent in nx.ancestors(self._graph, topic['root']):
topic['latest_time'] = max(topic['latest_time'], prompt.timestamp)
self._topics_table.update(topic, doc_ids=[topic.doc_id])
break
else:
self._topics_table.insert({
'root': prompt.hash,
'leaves': [prompt.hash],
'latest_time': prompt.timestamp,
'title': None,
'hidden': False
Expand Down Expand Up @@ -190,10 +169,11 @@ def select_topics(self, start: int, end: int) -> List[Dict[str, Any]]:
Returns:
List[Dict[str, Any]]: A list of dictionaries containing root prompts
with latest_time, title, and hidden fields.
with latest_time, and title fields.
"""
topics = self._topics_table.all()
sorted_topics = sorted(topics, key=lambda x: x['latest_time'], reverse=True)
visible_topics = self._topics_table.search(
where('hidden') == False) # pylint: disable=C0121
sorted_topics = sorted(visible_topics, key=lambda x: x['latest_time'], reverse=True)

topics = []
for topic in sorted_topics[start:end]:
Expand Down
19 changes: 17 additions & 2 deletions tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,17 @@ def test_select_recent(tmp_path):
assert prompt.hash == hashes[4 - index]


def test_select_recent_with_topic(tmp_path):
def test_select_topics_no_topics(tmp_path):
config = OpenAIChatConfig(model="gpt-3.5-turbo")
chat = OpenAIChat(config)
store = Store(tmp_path / "store.graphml", chat)

# Test selecting topics when there are no topics
topics = store.select_topics(0, 5)
assert len(topics) == 0


def test_select_topics_and_prompts_with_single_root(tmp_path):
config = OpenAIChatConfig(model="gpt-3.5-turbo")
chat = OpenAIChat(config)
store = Store(tmp_path / "store.graphml", chat)
Expand Down Expand Up @@ -124,7 +134,12 @@ def test_select_recent_with_topic(tmp_path):
store.store_prompt(child_prompt)
child_hashes.append(child_prompt.hash)

# Test selecting recent prompts within the topic
# Test selecting topics
topics = store.select_topics(0, 5)
assert len(topics) == 1
assert topics[0]['root_prompt'].hash == root_prompt.hash

# Test selecting prompts within the topic
recent_prompts = store.select_prompts(0, 2, topic=root_prompt.hash)
assert len(recent_prompts) == 2
for index, prompt in enumerate(recent_prompts):
Expand Down

0 comments on commit d96528f

Please sign in to comment.