From 3a4c5994a532d8ad388a8642695cbf2a2c4b6d58 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Sat, 1 Jun 2024 12:01:56 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E9=9B=86=E6=88=90model=5Fproviders?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chatchat-server/chatchat/init_database.py | 126 ++++++++++++++---- 1 file changed, 98 insertions(+), 28 deletions(-) diff --git a/libs/chatchat-server/chatchat/init_database.py b/libs/chatchat-server/chatchat/init_database.py index 67c289637..7a1baec75 100644 --- a/libs/chatchat-server/chatchat/init_database.py +++ b/libs/chatchat-server/chatchat/init_database.py @@ -1,14 +1,42 @@ -import sys -sys.path.append("chatchat") +# Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作 +from typing import Dict from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db, folder2db, prune_db_docs, prune_folder_files) -from chatchat.configs import DEFAULT_EMBEDDING_MODEL +from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS +import multiprocessing as mp +import logging +logger = logging.getLogger(__name__) + from datetime import datetime +def run_init_model_provider( + model_platforms_shard: Dict, + started_event: mp.Event = None, + model_providers_cfg_path: str = None, + provider_host: str = None, + provider_port: int = None): + from chatchat.init_server import init_server + from chatchat.configs import (MODEL_PROVIDERS_CFG_PATH_CONFIG, + MODEL_PROVIDERS_CFG_HOST, + MODEL_PROVIDERS_CFG_PORT) + if model_providers_cfg_path is None: + model_providers_cfg_path = MODEL_PROVIDERS_CFG_PATH_CONFIG + if provider_host is None: + provider_host = MODEL_PROVIDERS_CFG_HOST + if provider_port is None: + provider_port = MODEL_PROVIDERS_CFG_PORT + + init_server(model_platforms_shard=model_platforms_shard, + started_event=started_event, + model_providers_cfg_path=model_providers_cfg_path, + provider_host=provider_host, + provider_port=provider_port) + + if __name__ == "__main__": import argparse - + parser = argparse.ArgumentParser(description="please specify only one operate method once time.") parser.add_argument( @@ -92,27 +120,69 @@ args = parser.parse_args() start_time = datetime.now() - if args.create_tables: - create_tables() # confirm tables exist - - if args.clear_tables: - reset_tables() - print("database tables reset") - - if args.recreate_vs: - create_tables() - print("recreating all vector stores") - folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model) - elif args.import_db: - import_from_db(args.import_db) - elif args.update_in_db: - folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model) - elif args.increment: - folder2db(kb_names=args.kb_name, mode="increment", embed_model=args.embed_model) - elif args.prune_db: - prune_db_docs(args.kb_name) - elif args.prune_folder: - prune_folder_files(args.kb_name) - - end_time = datetime.now() - print(f"总计用时: {end_time-start_time}") + mp.set_start_method("spawn") + manager = mp.Manager() + + # 定义全局配置变量,使用 Manager 创建共享字典 + model_platforms_shard = manager.dict() + model_providers_started = manager.Event() + processes = {} + process = mp.Process( + target=run_init_model_provider, + name=f"Model providers Server", + kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=model_providers_started), + daemon=True, + ) + processes["model_providers"] = process + try: + # 保证任务收到SIGINT后,能够正常退出 + if p := processes.get("model_providers"): + p.start() + p.name = f"{p.name} ({p.pid})" + model_providers_started.wait() # 等待model_providers启动完成 + MODEL_PLATFORMS.extend(model_platforms_shard['provider_platforms']) + logger.info(f"Api MODEL_PLATFORMS: {MODEL_PLATFORMS}") + + + if args.create_tables: + create_tables() # confirm tables exist + + if args.clear_tables: + reset_tables() + print("database tables reset") + + if args.recreate_vs: + create_tables() + print("recreating all vector stores") + folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model) + elif args.import_db: + import_from_db(args.import_db) + elif args.update_in_db: + folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model) + elif args.increment: + folder2db(kb_names=args.kb_name, mode="increment", embed_model=args.embed_model) + elif args.prune_db: + prune_db_docs(args.kb_name) + elif args.prune_folder: + prune_folder_files(args.kb_name) + + end_time = datetime.now() + print(f"总计用时: {end_time-start_time}") + except Exception as e: + logger.error(e) + logger.warning("Caught KeyboardInterrupt! Setting stop event...") + finally: + + for p in processes.values(): + logger.warning("Sending SIGKILL to %s", p) + # Queues and other inter-process communication primitives can break when + # process is killed, but we don't care here + + if isinstance(p, dict): + for process in p.values(): + process.kill() + else: + p.kill() + + for p in processes.values(): + logger.info("Process status: %s", p) From 487044a13fdfdbe6fd18039a0746db30d737fbe2 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Sat, 1 Jun 2024 12:16:56 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E5=85=B3=E9=97=AD=E5=AE=88=E6=8A=A4?= =?UTF-8?q?=E8=BF=9B=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- libs/chatchat-server/chatchat/startup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/chatchat-server/chatchat/startup.py b/libs/chatchat-server/chatchat/startup.py index 01a7b10c3..992864b3f 100644 --- a/libs/chatchat-server/chatchat/startup.py +++ b/libs/chatchat-server/chatchat/startup.py @@ -284,7 +284,7 @@ def process_count(): target=run_api_server, name=f"API Server", kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=api_started, run_mode=run_mode), - daemon=True, + daemon=False, ) processes["api"] = process @@ -367,4 +367,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main()