Skip to content

Commit

Permalink
Add topic management to Store class
Browse files Browse the repository at this point in the history
- Add '_topics_table' to manage topics
- Implement _initialize_topics_table(), and _update_topics_table()
- Update store_prompt() to handle topic updates
  • Loading branch information
basicthinker committed Jun 5, 2023
1 parent e57f850 commit 8e0d40a
Showing 1 changed file with 82 additions and 17 deletions.
99 changes: 82 additions & 17 deletions devchat/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from devchat.chat import Chat
from devchat.prompt import Prompt


logger = logging.getLogger(__name__)


Expand All @@ -18,7 +17,8 @@ def __init__(self, store_dir: str, chat: Chat):
Initializes a Store instance.
Args:
path (str): The folder to store the files containing the store.
store_dir (str): The folder to store the files containing the store.
chat (Chat): The Chat instance.
"""
store_dir = os.path.expanduser(store_dir)
if not os.path.isdir(store_dir):
Expand All @@ -37,20 +37,61 @@ def __init__(self, store_dir: str, chat: Chat):
self._graph = nx.DiGraph()

self._db = TinyDB(self._db_path)
self._topics_table = self._db.table('topics')

if not self._topics_table or self._topics_table.all() == []:
self._initialize_topics_table()

def _roots(self) -> List[object]:
graph = self._graph
return [node for node in graph.nodes() if graph.out_degree(node) == 0]

def _leaves(self) -> List[object]:
graph = self._graph
return [node for node in graph.nodes() if graph.in_degree(node) == 0]

def _initialize_topics_table(self):
roots = self._roots()
all_leaves = self._leaves()
root_to_leaves = {root: [] for root in roots}

for leaf in all_leaves:
root = next(nx.descendants(self._graph, leaf).intersection(roots))
root_to_leaves[root].append(leaf)

for root, leaves in root_to_leaves.items():
latest_time = max(self._graph.nodes[leaf]['timestamp'] for leaf in leaves)
self._topics_table.insert({
'root': root,
'leaves': leaves,
'latest_time': latest_time,
'title': None,
'hidden': False
})

def _update_topics_table(self, prompt: Prompt):
if self._graph.in_degree(prompt):
logger.error("Prompt %s not a leaf to update topics table", prompt.hash)

@property
def graph_path(self) -> str:
"""
The path to the graph store file.
"""
return self._graph_path

@property
def db_path(self) -> str:
"""
The path to the object store file.
"""
return self._db_path
if prompt.parent:
topic = next((topic for topic in self._topics_table.all()
if prompt.parent in topic['leaves']), None)
if topic:
topic['leaves'].remove(prompt.parent)
topic['leaves'].append(prompt.hash)
topic['latest_time'] = max(topic['latest_time'], prompt.timestamp)
self._topics_table.update(topic, doc_ids=[topic.doc_id])
else:
logger.error("Parent %s of prompt %s not found in topic leaves",
prompt.parent, prompt.hash)
else:
self._topics_table.insert({
'root': prompt.hash,
'leaves': [prompt.hash],
'latest_time': prompt.timestamp,
'title': None,
'hidden': False
})

def store_prompt(self, prompt: Prompt):
"""
Expand All @@ -71,10 +112,20 @@ def store_prompt(self, prompt: Prompt):
# Add edges for parents and references
if prompt.parent:
if prompt.parent not in self._graph:
logger.warning("Parent %s not found while Prompt %s is stored to graph store.",
prompt.parent, prompt.hash)
logger.error("Parent %s not found while Prompt %s is stored to graph store.",
prompt.parent, prompt.hash)
else:
self._graph.add_edge(prompt.hash, prompt.parent)
self._update_topics_table(prompt)
else:
self._topics_table.insert({
'ancestor_hash': prompt.hash,
'descendant_hash': prompt.hash,
'latest_time': prompt.timestamp,
'title': None,
'hidden': False
})

for reference_hash in prompt.references:
if reference_hash not in self._graph:
logger.warning("Reference %s not found while Prompt %s is stored to graph store.",
Expand Down Expand Up @@ -129,3 +180,17 @@ def select_recent(self, start: int, end: int) -> List[Prompt]:
continue
prompts.append(prompt)
return prompts

@property
def graph_path(self) -> str:
"""
The path to the graph store file.
"""
return self._graph_path

@property
def db_path(self) -> str:
"""
The path to the object store file.
"""
return self._db_path

0 comments on commit 8e0d40a

Please sign in to comment.