Skip to content

Commit

Permalink
Merge pull request #4122 from chatchat-space/dev_init_database_providers
Browse files Browse the repository at this point in the history
Dev init database providers关闭守护进程
  • Loading branch information
glide-the committed Jun 2, 2024
2 parents 10c5dcf + 487044a commit bc6832b
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 30 deletions.
126 changes: 98 additions & 28 deletions libs/chatchat-server/chatchat/init_database.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,42 @@
import sys
sys.path.append("chatchat")
# Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作
from typing import Dict
from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
folder2db, prune_db_docs, prune_folder_files)
from chatchat.configs import DEFAULT_EMBEDDING_MODEL
from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS
import multiprocessing as mp
import logging
logger = logging.getLogger(__name__)

from datetime import datetime


def run_init_model_provider(
model_platforms_shard: Dict,
started_event: mp.Event = None,
model_providers_cfg_path: str = None,
provider_host: str = None,
provider_port: int = None):
from chatchat.init_server import init_server
from chatchat.configs import (MODEL_PROVIDERS_CFG_PATH_CONFIG,
MODEL_PROVIDERS_CFG_HOST,
MODEL_PROVIDERS_CFG_PORT)
if model_providers_cfg_path is None:
model_providers_cfg_path = MODEL_PROVIDERS_CFG_PATH_CONFIG
if provider_host is None:
provider_host = MODEL_PROVIDERS_CFG_HOST
if provider_port is None:
provider_port = MODEL_PROVIDERS_CFG_PORT

init_server(model_platforms_shard=model_platforms_shard,
started_event=started_event,
model_providers_cfg_path=model_providers_cfg_path,
provider_host=provider_host,
provider_port=provider_port)


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description="please specify only one operate method once time.")

parser.add_argument(
Expand Down Expand Up @@ -92,27 +120,69 @@
args = parser.parse_args()
start_time = datetime.now()

if args.create_tables:
create_tables() # confirm tables exist

if args.clear_tables:
reset_tables()
print("database tables reset")

if args.recreate_vs:
create_tables()
print("recreating all vector stores")
folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
elif args.import_db:
import_from_db(args.import_db)
elif args.update_in_db:
folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model)
elif args.increment:
folder2db(kb_names=args.kb_name, mode="increment", embed_model=args.embed_model)
elif args.prune_db:
prune_db_docs(args.kb_name)
elif args.prune_folder:
prune_folder_files(args.kb_name)

end_time = datetime.now()
print(f"总计用时: {end_time-start_time}")
mp.set_start_method("spawn")
manager = mp.Manager()

# 定义全局配置变量,使用 Manager 创建共享字典
model_platforms_shard = manager.dict()
model_providers_started = manager.Event()
processes = {}
process = mp.Process(
target=run_init_model_provider,
name=f"Model providers Server",
kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=model_providers_started),
daemon=True,
)
processes["model_providers"] = process
try:
# 保证任务收到SIGINT后,能够正常退出
if p := processes.get("model_providers"):
p.start()
p.name = f"{p.name} ({p.pid})"
model_providers_started.wait() # 等待model_providers启动完成
MODEL_PLATFORMS.extend(model_platforms_shard['provider_platforms'])
logger.info(f"Api MODEL_PLATFORMS: {MODEL_PLATFORMS}")


if args.create_tables:
create_tables() # confirm tables exist

if args.clear_tables:
reset_tables()
print("database tables reset")

if args.recreate_vs:
create_tables()
print("recreating all vector stores")
folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
elif args.import_db:
import_from_db(args.import_db)
elif args.update_in_db:
folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model)
elif args.increment:
folder2db(kb_names=args.kb_name, mode="increment", embed_model=args.embed_model)
elif args.prune_db:
prune_db_docs(args.kb_name)
elif args.prune_folder:
prune_folder_files(args.kb_name)

end_time = datetime.now()
print(f"总计用时: {end_time-start_time}")
except Exception as e:
logger.error(e)
logger.warning("Caught KeyboardInterrupt! Setting stop event...")
finally:

for p in processes.values():
logger.warning("Sending SIGKILL to %s", p)
# Queues and other inter-process communication primitives can break when
# process is killed, but we don't care here

if isinstance(p, dict):
for process in p.values():
process.kill()
else:
p.kill()

for p in processes.values():
logger.info("Process status: %s", p)
4 changes: 2 additions & 2 deletions libs/chatchat-server/chatchat/startup.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def process_count():
target=run_api_server,
name=f"API Server",
kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=api_started, run_mode=run_mode),
daemon=True,
daemon=False,
)
processes["api"] = process

Expand Down Expand Up @@ -367,4 +367,4 @@ def main():


if __name__ == "__main__":
main()
main()

0 comments on commit bc6832b

Please sign in to comment.