Skip to content

Commit

Permalink
Feat (#1951)
Browse files Browse the repository at this point in the history
* 知识库支持子目录(不包括temp和tmp开头的目录),文件相对路径总长度不可超过255

* init_database.py 增加 --import-db 参数,在版本升级时,如果 info.db
表结构发生变化,但向量库无需重建,可以在重建数据库后,使用本参数从旧的数据库中导入信息
  • Loading branch information
liunux4odoo committed Nov 2, 2023
1 parent d8e15b5 commit 554122f
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 6 deletions.
9 changes: 8 additions & 1 deletion init_database.py
@@ -1,6 +1,7 @@
import sys
sys.path.append(".")
from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files
from server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
folder2db, prune_db_docs, prune_folder_files)
from configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
Expand Down Expand Up @@ -28,6 +29,10 @@
action="store_true",
help=("drop the database tables before recreate vector stores")
)
parser.add_argument(
"--import-db",
help="import tables from specified sqlite database"
)
parser.add_argument(
"-u",
"--update-in-db",
Expand Down Expand Up @@ -97,6 +102,8 @@
if args.recreate_vs:
print("recreating all vector stores")
folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
elif args.import_db:
import_from_db(args.import_db)
elif args.update_in_db:
folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model)
elif args.increament:
Expand Down
4 changes: 2 additions & 2 deletions server/db/base.py
@@ -1,5 +1,5 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
from sqlalchemy.orm import sessionmaker

from configs import SQLALCHEMY_DATABASE_URI
Expand All @@ -13,4 +13,4 @@

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()
Base: DeclarativeMeta = declarative_base()
3 changes: 2 additions & 1 deletion server/db/session.py
@@ -1,10 +1,11 @@
from functools import wraps
from contextlib import contextmanager
from server.db.base import SessionLocal
from sqlalchemy.orm import Session


@contextmanager
def session_scope():
def session_scope() -> Session:
"""上下文管理器用于自动获取 Session, 避免错误"""
session = SessionLocal()
try:
Expand Down
45 changes: 43 additions & 2 deletions server/knowledge_base/migrate.py
Expand Up @@ -5,10 +5,12 @@
list_files_from_folder,files2docs_in_thread,
KnowledgeFile,)
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_file_repository import add_file_to_db
from server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported
from server.db.base import Base, engine
from server.db.session import session_scope
import os
from typing import Literal, Any, List
from dateutil.parser import parse
from typing import Literal, List


def create_tables():
Expand All @@ -20,6 +22,45 @@ def reset_tables():
create_tables()


def import_from_db(
sqlite_path: str = None,
# csv_path: str = None,
) -> bool:
'''
在知识库与向量库无变化的情况下,从备份数据库中导入数据到 info.db。
适用于版本升级时,info.db 结构变化,但无需重新向量化的情况。
请确保两边数据库表名一致,需要导入的字段名一致
当前仅支持 sqlite
'''
import sqlite3 as sql
from pprint import pprint

models = list(Base.registry.mappers)

try:
con = sql.connect(sqlite_path)
con.row_factory = sql.Row
cur = con.cursor()
tables = [x["name"] for x in cur.execute("select name from sqlite_master where type='table'").fetchall()]
for model in models:
table = model.local_table.fullname
if table not in tables:
continue
print(f"processing table: {table}")
with session_scope() as session:
for row in cur.execute(f"select * from {table}").fetchall():
data = {k: row[k] for k in row.keys() if k in model.columns}
if "create_time" in data:
data["create_time"] = parse(data["create_time"])
pprint(data)
session.add(model.class_(**data))
con.close()
return True
except Exception as e:
print(f"无法读取备份数据库:{sqlite_path}。错误信息:{e}")
return False


def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
kb_files = []
for file in files:
Expand Down

0 comments on commit 554122f

Please sign in to comment.