Skip to content

Commit

Permalink
Add database migration functionality
Browse files Browse the repository at this point in the history
- Added _migrate_db() for database migrations.
- Checks current database version for updates.
- Replaced 'response' with 'responses' in dictionary.
  • Loading branch information
basicthinker committed Jul 23, 2023
1 parent d95488d commit 1ecafe8
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions devchat/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from typing import List, Dict, Any, Optional
from xml.etree.ElementTree import ParseError
import networkx as nx
from tinydb import TinyDB, where
from tinydb import TinyDB, where, Query
from tinydb.table import Table
from devchat.chat import Chat
from devchat.prompt import Prompt
from devchat.utils import get_logger
Expand Down Expand Up @@ -37,11 +38,34 @@ def __init__(self, store_dir: str, chat: Chat):
self._graph = nx.DiGraph()

self._db = TinyDB(self._db_path)
self._db_meta = self._migrate_db()
self._topics_table = self._db.table('topics')

if not self._topics_table or self._topics_table.all() == []:
if not self._topics_table or not self._topics_table.all():
self._initialize_topics_table()

def _migrate_db(self) -> Table:
"""
Migrate the database to the latest version.
"""
metadata = self._db.table('metadata')

result = metadata.get(where('version').exists())
if not result or result['version'].startswith('0.1.'):
def replace_response():
def transform(doc):
if '_new_messages' not in doc or 'response' not in doc['_new_messages']:
logger.error("Prompt %s does not match '_new_messages.response'",
doc['_hash'])
doc['_new_messages']['responses'] = doc['_new_messages'].pop('response')
return transform

logger.info("Migrating database from %s to 0.2.0", result)
self._db.update(replace_response(),
Query()._new_messages.response.exists()) # pylint: disable=W0212
metadata.insert({'version': '0.2.0'})
return metadata

def _initialize_topics_table(self):
roots = [node for node in self._graph.nodes() if self._graph.out_degree(node) == 0]
for root in roots:
Expand Down

0 comments on commit 1ecafe8

Please sign in to comment.