diff --git a/docker/bisheng/config/config.yaml b/docker/bisheng/config/config.yaml index b975ce07..941bd89d 100644 --- a/docker/bisheng/config/config.yaml +++ b/docker/bisheng/config/config.yaml @@ -3,10 +3,20 @@ database_url: "mysql+pymysql://root:gAAAAABlp4b4c59FeVGF_OQRVf6NOUIGdxq8246EBD-b0hdK_jVKRs1x4PoAn0A6C5S6IiFKmWn0Nm5eBUWu-7jxcqw6TiVjQA==@mysql:3306/bisheng?charset=utf8mb4" # 缓存配置 redis://[[username]:[password]]@localhost:6379/0 +# 普通模式: +#redis_url: "redis://redis:6379/1" +# 集群模式: redis_url: - sentinel_hosts: [("redis","6379")] - sentinel_master: "mymaster" - sentinel_password: "gAAAAABlp4b4c59FeVGF_OQRVf6NOUIGdxq8246EBD-b0hdK_jVKRs1x4PoAn0A6C5S6IiFKmWn0Nm5eBUWu-7jxcqw6TiVjQA==" + mode: "cluster" + startup_nodes: + - {"host": "192.168.106.115", "port": 6002} + # password: encrypt(gAAAAABlp4b4c59FeVGF_OQRVf6NOUIGdxq8246EBD-b0hdK_jVKRs1x4PoAn0A6C5S6IiFKmWn0Nm5eBUWu-7jxcqw6TiVjQA==) + # #sentinel + # mode: "sentinel" + # sentinel_hosts: [("redis", 6379)] + # sentinel_master: "mymaster" + # sentinel_password: encrypt(gAAAAABlp4b4c59FeVGF_OQRVf6NOUIGdxq8246EBD-b0hdK_jVKRs1x4PoAn0A6C5S6IiFKmWn0Nm5eBUWu-7jxcqw6TiVjQA==) + # db: 1 environment: env: dev diff --git a/src/backend/bisheng/api/v1/report.py b/src/backend/bisheng/api/v1/report.py index a9d2be8a..e4abbd8d 100644 --- a/src/backend/bisheng/api/v1/report.py +++ b/src/backend/bisheng/api/v1/report.py @@ -65,7 +65,7 @@ async def get_template(*, flow_id: str): with session_getter() as session: session.add(db_report) session.commit() - session.refresh() + session.refresh(db_report) else: version_key = db_report.newversion_key res = { diff --git a/src/backend/bisheng/cache/redis.py b/src/backend/bisheng/cache/redis.py index d93afdfc..adfd4d4b 100644 --- a/src/backend/bisheng/cache/redis.py +++ b/src/backend/bisheng/cache/redis.py @@ -4,7 +4,8 @@ import redis from bisheng.settings import settings from loguru import logger -from redis import ConnectionPool +from redis import ConnectionPool, RedisCluster +from redis.cluster import ClusterNode from redis.sentinel import Sentinel @@ -13,15 +14,23 @@ class RedisClient: def __init__(self, url, max_connections=10): # # 哨兵模式 if isinstance(settings.redis_url, Dict): - redis_conf = settings.redis_url - hosts = [eval(x) for x in redis_conf.get('sentinel_hosts')] - sentinel = Sentinel(sentinels=hosts, - socket_timeout=0.1, - password=redis_conf.get('sentinel_password')) + redis_conf = dict(settings.redis_url) + mode = redis_conf.pop('mode', 'sentinel') + if mode == 'cluster': + # 集群模式 + if 'startup_nodes' in redis_conf: + redis_conf['startup_nodes'] = [ + ClusterNode(node.get('host'), node.get('port')) + for node in redis_conf['startup_nodes'] + ] + self.connection = RedisCluster(**redis_conf) + return + hosts = [eval(x) for x in redis_conf.pop('sentinel_hosts')] + password = redis_conf.pop('sentinel_password') + master = redis_conf.pop('sentinel_master') + sentinel = Sentinel(sentinels=hosts, socket_timeout=0.1, password=password) # 获取主节点的连接 - self.connection = sentinel.master_for(redis_conf.get('sentinel_master'), - socket_timeout=0.1, - db=1) + self.connection = sentinel.master_for(master, socket_timeout=0.1, **redis_conf) else: # 单机模式 @@ -31,6 +40,7 @@ def __init__(self, url, max_connections=10): def set(self, key, value, expiration=3600): try: if pickled := pickle.dumps(value): + self.cluster_nodes(key) result = self.connection.setex(key, expiration, pickled) if not result: raise ValueError('RedisCache could not set the value.') @@ -44,6 +54,7 @@ def set(self, key, value, expiration=3600): def setNx(self, key, value, expiration=3600): try: if pickled := pickle.dumps(value): + self.cluster_nodes(key) result = self.connection.setnx(key, pickled) self.connection.expire(key, expiration) if not result: @@ -56,6 +67,7 @@ def setNx(self, key, value, expiration=3600): def hsetkey(self, name, key, value, expiration=3600): try: + self.cluster_nodes(key) r = self.connection.hset(name, key, value) if expiration: self.connection.expire(name, expiration) @@ -65,6 +77,7 @@ def hsetkey(self, name, key, value, expiration=3600): def hset(self, name, map: dict, expiration=3600): try: + self.cluster_nodes(name) r = self.connection.hset(name, mapping=map) if expiration: self.connection.expire(name, expiration) @@ -74,12 +87,14 @@ def hset(self, name, map: dict, expiration=3600): def hget(self, name, key): try: + self.cluster_nodes(name) return self.connection.hget(name, key) finally: self.close() def get(self, key): try: + self.cluster_nodes(key) value = self.connection.get(key) return pickle.loads(value) if value else None finally: @@ -87,12 +102,14 @@ def get(self, key): def delete(self, key): try: + self.cluster_nodes(key) return self.connection.delete(key) finally: self.close() def exists(self, key): try: + self.cluster_nodes(key) return self.connection.exists(key) finally: self.close() @@ -102,19 +119,29 @@ def close(self): def __contains__(self, key): """Check if the key is in the cache.""" + self.cluster_nodes(key) return False if key is None else self.connection.exists(key) def __getitem__(self, key): """Retrieve an item from the cache using the square bracket notation.""" - return self.get(key) + self.cluster_nodes(key) + return self.connection.get(key) def __setitem__(self, key, value): """Add an item to the cache using the square bracket notation.""" - self.set(key, value) + self.cluster_nodes(key) + self.connection.set(key, value) def __delitem__(self, key): """Remove an item from the cache using the square bracket notation.""" - self.delete(key) + self.cluster_nodes(key) + self.connection.delete(key) + + def cluster_nodes(self, key): + if isinstance(self.connection, + RedisCluster) and self.connection.get_default_node() is None: + target = self.connection.get_node_from_key(key) + self.connection.set_default_node(target) # 示例用法 diff --git a/src/backend/bisheng/chat/handlers.py b/src/backend/bisheng/chat/handlers.py index 09a78764..15e1f2a7 100644 --- a/src/backend/bisheng/chat/handlers.py +++ b/src/backend/bisheng/chat/handlers.py @@ -216,13 +216,15 @@ async def process_file(self, session: ChatManager, client_id: str, chat_id: str, key = get_cache_key(client_id, chat_id) langchain_object = session.in_memory_cache.get(key) input_key = langchain_object.input_keys[0] + input_dict = {k: '' for k in langchain_object.input_keys} report = '' logger.info(f'process_file batch_question={batch_question}') for question in batch_question: if not question: continue - payload = {'inputs': {input_key: question}, 'is_begin': False} + input_dict[input_key] = question + payload = {'inputs': input_dict, 'is_begin': False} start_resp.category == 'question' await session.send_json(client_id, chat_id, start_resp) step_resp = ChatResponse(type='end', diff --git a/src/backend/bisheng/default_node.yaml b/src/backend/bisheng/default_node.yaml index 3851d484..0d64b429 100644 --- a/src/backend/bisheng/default_node.yaml +++ b/src/backend/bisheng/default_node.yaml @@ -235,6 +235,8 @@ prompts: documentation: "https://python.langchain.com/docs/modules/model_io/models/chat/how_to/prompts" PromptTemplate: documentation: "https://python.langchain.com/docs/modules/model_io/prompts/prompt_templates/" + MessagesPlaceholder: + documentation: "" textsplitters: CharacterTextSplitter: documentation: "https://python.langchain.com/docs/modules/data_connection/document_transformers/text_splitters/character_text_splitter" diff --git a/src/backend/bisheng/interface/initialize/loading.py b/src/backend/bisheng/interface/initialize/loading.py index e8155793..372418e1 100644 --- a/src/backend/bisheng/interface/initialize/loading.py +++ b/src/backend/bisheng/interface/initialize/loading.py @@ -1,6 +1,5 @@ -import contextlib import json -from typing import Any, Callable, Dict, List, Sequence, Type +from typing import Any, Callable, Dict, Sequence, Type from bisheng.cache.utils import file_download from bisheng.chat.config import ChatConfig @@ -9,6 +8,8 @@ from bisheng.interface.custom_lists import CUSTOM_NODES from bisheng.interface.importing.utils import get_function, import_by_type from bisheng.interface.initialize.llm import initialize_vertexai +from bisheng.interface.initialize.utils import (handle_format_kwargs, handle_node_type, + handle_partial_variables) from bisheng.interface.initialize.vector_store import vecstore_initializer from bisheng.interface.output_parsers.base import output_parser_creator from bisheng.interface.retrievers.base import retriever_creator @@ -19,7 +20,6 @@ from bisheng.utils import validate from bisheng.utils.constants import NODE_ID_DICT, PRESET_QUESTION from bisheng_langchain.vectorstores import VectorStoreFilterRetriever -from langchain.agents import ZeroShotAgent from langchain.agents import agent as agent_module from langchain.agents.agent import AgentExecutor from langchain.agents.agent_toolkits.base import BaseToolkit @@ -27,7 +27,6 @@ from langchain.base_language import BaseLanguageModel from langchain.chains.base import Chain from langchain.document_loaders.base import BaseLoader -from langchain.schema import BaseOutputParser, Document from langchain.vectorstores.base import VectorStore from loguru import logger from pydantic import ValidationError, create_model @@ -344,78 +343,22 @@ def instantiate_agent(node_type, class_object: Type[agent_module.Agent], params: def instantiate_prompt(node_type, class_object, params: Dict, param_id_dict: Dict): - - if node_type == 'ZeroShotPrompt': - if 'tools' not in params: - params['tools'] = [] - return ZeroShotAgent.create_prompt(**params) - elif 'MessagePromptTemplate' in node_type: - # Then we only need the template - from_template_params = {'template': params.pop('prompt', params.pop('template', ''))} - - if not from_template_params.get('template'): - raise ValueError('Prompt template is required') - prompt = class_object.from_template(**from_template_params) - - elif node_type == 'ChatPromptTemplate': - prompt = class_object.from_messages(**params) - else: - prompt = class_object(**params) + params, prompt = handle_node_type(node_type, class_object, params) + format_kwargs = handle_format_kwargs(prompt, params) + # Now we'll use partial_format to format the prompt + if format_kwargs: + prompt = handle_partial_variables(prompt, format_kwargs) no_human_input = set(param_id_dict.keys()) human_input = set(prompt.input_variables).difference(no_human_input) order_input = list(human_input) + list(set(prompt.input_variables) & no_human_input) - prompt.input_variables = order_input - format_kwargs: Dict[str, Any] = {} - for input_variable in prompt.input_variables: - if input_variable in params: - variable = params[input_variable] - if isinstance(variable, str): - format_kwargs[input_variable] = variable - elif isinstance(variable, BaseOutputParser) and hasattr(variable, - 'get_format_instructions'): - format_kwargs[input_variable] = variable.get_format_instructions() - elif isinstance(variable, dict): - # variable node - if len(variable) == 0: - format_kwargs[input_variable] = '' - continue - elif len(variable) != 1: - raise ValueError(f'VariableNode contains multi-key {variable.keys()}') - format_kwargs[input_variable] = list(variable.values())[0] - elif isinstance(variable, List) and all( - isinstance(item, Document) for item in variable): - # Format document to contain page_content and metadata - # as one string separated by a newline - if len(variable) > 1: - content = '\n'.join( - [item.page_content for item in variable if item.page_content]) - else: - if not variable: - format_kwargs[input_variable] = '' - continue - content = variable[0].page_content - # content could be a json list of strings - with contextlib.suppress(json.JSONDecodeError): - content = json.loads(content) - if isinstance(content, list): - content = ','.join([str(item) for item in content]) - format_kwargs[input_variable] = content - # handle_keys will be a list but it does not exist yet - # so we need to create it - - if (isinstance(variable, List) and all( - isinstance(item, Document) - for item in variable)) or (isinstance(variable, BaseOutputParser) - and hasattr(variable, 'get_format_instructions')): - if 'handle_keys' not in format_kwargs: - format_kwargs['handle_keys'] = [] - - # Add the handle_keys to the list - format_kwargs['handle_keys'].append(input_variable) - - # from langchain.chains.router.llm_router import RouterOutputParser - # prompt.output_parser = RouterOutputParser() + if len(order_input) > 1: + # if node_type == 'ChatPromptTemplate': + + if hasattr(prompt, 'prompt') and hasattr(prompt.prompt, 'input_variables'): + prompt.prompt.input_variables = order_input + elif hasattr(prompt, 'input_variables'): + prompt.input_variables = order_input return prompt, format_kwargs diff --git a/src/backend/bisheng/interface/initialize/utils.py b/src/backend/bisheng/interface/initialize/utils.py new file mode 100644 index 00000000..0c456a2e --- /dev/null +++ b/src/backend/bisheng/interface/initialize/utils.py @@ -0,0 +1,110 @@ +import contextlib +import json +from typing import Any, Dict, List + +import orjson +from bisheng.database.models.base import orjson_dumps +from langchain.agents import ZeroShotAgent +from langchain.schema import BaseOutputParser, Document + + +def handle_node_type(node_type, class_object, params: Dict): + if node_type == 'ZeroShotPrompt': + params = check_tools_in_params(params) + prompt = ZeroShotAgent.create_prompt(**params) + elif 'MessagePromptTemplate' in node_type: + prompt = instantiate_from_template(class_object, params) + elif node_type == 'ChatPromptTemplate': + prompt = class_object.from_messages(**params) + elif hasattr(class_object, 'from_template') and params.get('template'): + prompt = class_object.from_template(template=params.pop('template')) + else: + prompt = class_object(**params) + return params, prompt + + +def check_tools_in_params(params: Dict): + if 'tools' not in params: + params['tools'] = [] + return params + + +def instantiate_from_template(class_object, params: Dict): + from_template_params = {'template': params.pop('prompt', params.pop('template', ''))} + if not from_template_params.get('template'): + raise ValueError('Prompt template is required') + return class_object.from_template(**from_template_params) + + +def handle_format_kwargs(prompt, params: Dict): + format_kwargs: Dict[str, Any] = {} + for input_variable in prompt.input_variables: + if input_variable in params: + format_kwargs = handle_variable(params, input_variable, format_kwargs) + return format_kwargs + + +def handle_partial_variables(prompt, format_kwargs: Dict): + partial_variables = format_kwargs.copy() + partial_variables = {key: value for key, value in partial_variables.items() if value} + # Remove handle_keys otherwise LangChain raises an error + partial_variables.pop('handle_keys', None) + if partial_variables and hasattr(prompt, 'partial'): + return prompt.partial(**partial_variables) + return prompt + + +def handle_variable(params: Dict, input_variable: str, format_kwargs: Dict): + variable = params[input_variable] + if isinstance(variable, str): + format_kwargs[input_variable] = variable + elif isinstance(variable, BaseOutputParser) and hasattr(variable, 'get_format_instructions'): + format_kwargs[input_variable] = variable.get_format_instructions() + elif is_instance_of_list_or_document(variable): + format_kwargs = format_document(variable, input_variable, format_kwargs) + if needs_handle_keys(variable): + format_kwargs = add_handle_keys(input_variable, format_kwargs) + return format_kwargs + + +def is_instance_of_list_or_document(variable): + return (isinstance(variable, List) and all(isinstance(item, Document) for item in variable) + or isinstance(variable, Document)) + + +def format_document(variable, input_variable: str, format_kwargs: Dict): + variable = variable if isinstance(variable, List) else [variable] + content = format_content(variable) + format_kwargs[input_variable] = content + return format_kwargs + + +def format_content(variable): + if len(variable) > 1: + return '\n'.join([item.page_content for item in variable if item.page_content]) + elif len(variable) == 1: + content = variable[0].page_content + return try_to_load_json(content) + return '' + + +def try_to_load_json(content): + with contextlib.suppress(json.JSONDecodeError): + content = orjson.loads(content) + if isinstance(content, list): + content = ','.join([str(item) for item in content]) + else: + content = orjson_dumps(content) + return content + + +def needs_handle_keys(variable): + return is_instance_of_list_or_document(variable) or (isinstance( + variable, BaseOutputParser) and hasattr(variable, 'get_format_instructions')) + + +def add_handle_keys(input_variable: str, format_kwargs: Dict): + if 'handle_keys' not in format_kwargs: + format_kwargs['handle_keys'] = [] + format_kwargs['handle_keys'].append(input_variable) + return format_kwargs diff --git a/src/backend/bisheng/template/frontend_node/documentloaders.py b/src/backend/bisheng/template/frontend_node/documentloaders.py index ad3136bf..eb29ea67 100644 --- a/src/backend/bisheng/template/frontend_node/documentloaders.py +++ b/src/backend/bisheng/template/frontend_node/documentloaders.py @@ -169,6 +169,15 @@ def add_extra_fields(self) -> None: display_name='unstructured_api_url', advanced=False, )) + self.template.add_field( + TemplateField( + field_type='dict', + required=True, + show=True, + name='kwargs', + display_name='kwargs', + advanced=False, + )) self.template.add_field(self.file_path_templates[self.template.type_name]) elif self.template.type_name in {'UniversalKVLoader'}: self.template.add_field( @@ -320,6 +329,7 @@ def format_field(field: TemplateField, name: Optional[str] = None) -> None: field.is_list = False if name == 'ElemUnstructuredLoaderV0' and field.name == 'kwargs': field.show = True + field.advanced = False def build_pdf_semantic_loader_fields():