Skip to content

Commit

Permalink
支持redis cluster and sentinel (#320)
Browse files Browse the repository at this point in the history
支持redis cluster and sentinel
  • Loading branch information
yaojin3616 committed Feb 1, 2024
2 parents e0eab7d + ee1483a commit 2fe6336
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 89 deletions.
16 changes: 13 additions & 3 deletions docker/bisheng/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,20 @@ database_url:
"mysql+pymysql://root:gAAAAABlp4b4c59FeVGF_OQRVf6NOUIGdxq8246EBD-b0hdK_jVKRs1x4PoAn0A6C5S6IiFKmWn0Nm5eBUWu-7jxcqw6TiVjQA==@mysql:3306/bisheng?charset=utf8mb4"

# 缓存配置 redis://[[username]:[password]]@localhost:6379/0
# 普通模式:
#redis_url: "redis://redis:6379/1"
# 集群模式:
redis_url:
sentinel_hosts: [("redis","6379")]
sentinel_master: "mymaster"
sentinel_password: "gAAAAABlp4b4c59FeVGF_OQRVf6NOUIGdxq8246EBD-b0hdK_jVKRs1x4PoAn0A6C5S6IiFKmWn0Nm5eBUWu-7jxcqw6TiVjQA=="
mode: "cluster"
startup_nodes:
- {"host": "192.168.106.115", "port": 6002}
# password: encrypt(gAAAAABlp4b4c59FeVGF_OQRVf6NOUIGdxq8246EBD-b0hdK_jVKRs1x4PoAn0A6C5S6IiFKmWn0Nm5eBUWu-7jxcqw6TiVjQA==)
# #sentinel
# mode: "sentinel"
# sentinel_hosts: [("redis", 6379)]
# sentinel_master: "mymaster"
# sentinel_password: encrypt(gAAAAABlp4b4c59FeVGF_OQRVf6NOUIGdxq8246EBD-b0hdK_jVKRs1x4PoAn0A6C5S6IiFKmWn0Nm5eBUWu-7jxcqw6TiVjQA==)
# db: 1

environment:
env: dev
Expand Down
2 changes: 1 addition & 1 deletion src/backend/bisheng/api/v1/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ async def get_template(*, flow_id: str):
with session_getter() as session:
session.add(db_report)
session.commit()
session.refresh()
session.refresh(db_report)
else:
version_key = db_report.newversion_key
res = {
Expand Down
51 changes: 39 additions & 12 deletions src/backend/bisheng/cache/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import redis
from bisheng.settings import settings
from loguru import logger
from redis import ConnectionPool
from redis import ConnectionPool, RedisCluster
from redis.cluster import ClusterNode
from redis.sentinel import Sentinel


Expand All @@ -13,15 +14,23 @@ class RedisClient:
def __init__(self, url, max_connections=10):
# # 哨兵模式
if isinstance(settings.redis_url, Dict):
redis_conf = settings.redis_url
hosts = [eval(x) for x in redis_conf.get('sentinel_hosts')]
sentinel = Sentinel(sentinels=hosts,
socket_timeout=0.1,
password=redis_conf.get('sentinel_password'))
redis_conf = dict(settings.redis_url)
mode = redis_conf.pop('mode', 'sentinel')
if mode == 'cluster':
# 集群模式
if 'startup_nodes' in redis_conf:
redis_conf['startup_nodes'] = [
ClusterNode(node.get('host'), node.get('port'))
for node in redis_conf['startup_nodes']
]
self.connection = RedisCluster(**redis_conf)
return
hosts = [eval(x) for x in redis_conf.pop('sentinel_hosts')]
password = redis_conf.pop('sentinel_password')
master = redis_conf.pop('sentinel_master')
sentinel = Sentinel(sentinels=hosts, socket_timeout=0.1, password=password)
# 获取主节点的连接
self.connection = sentinel.master_for(redis_conf.get('sentinel_master'),
socket_timeout=0.1,
db=1)
self.connection = sentinel.master_for(master, socket_timeout=0.1, **redis_conf)

else:
# 单机模式
Expand All @@ -31,6 +40,7 @@ def __init__(self, url, max_connections=10):
def set(self, key, value, expiration=3600):
try:
if pickled := pickle.dumps(value):
self.cluster_nodes(key)
result = self.connection.setex(key, expiration, pickled)
if not result:
raise ValueError('RedisCache could not set the value.')
Expand All @@ -44,6 +54,7 @@ def set(self, key, value, expiration=3600):
def setNx(self, key, value, expiration=3600):
try:
if pickled := pickle.dumps(value):
self.cluster_nodes(key)
result = self.connection.setnx(key, pickled)
self.connection.expire(key, expiration)
if not result:
Expand All @@ -56,6 +67,7 @@ def setNx(self, key, value, expiration=3600):

def hsetkey(self, name, key, value, expiration=3600):
try:
self.cluster_nodes(key)
r = self.connection.hset(name, key, value)
if expiration:
self.connection.expire(name, expiration)
Expand All @@ -65,6 +77,7 @@ def hsetkey(self, name, key, value, expiration=3600):

def hset(self, name, map: dict, expiration=3600):
try:
self.cluster_nodes(name)
r = self.connection.hset(name, mapping=map)
if expiration:
self.connection.expire(name, expiration)
Expand All @@ -74,25 +87,29 @@ def hset(self, name, map: dict, expiration=3600):

def hget(self, name, key):
    """Return the value stored under field ``key`` of the hash ``name``.

    ``cluster_nodes`` is called with the hash name before the read, and
    ``close`` is always called afterwards, even if the read raises.
    """
    try:
        self.cluster_nodes(name)
        field_value = self.connection.hget(name, key)
    finally:
        # Runs on both the success and the error path.
        self.close()
    return field_value

def get(self, key):
    """Fetch ``key`` and return the unpickled object stored there.

    Returns ``None`` when the connection yields a falsy value (missing key
    or empty payload). ``close`` is always called before returning.
    """
    try:
        self.cluster_nodes(key)
        raw = self.connection.get(key)
        if not raw:
            return None
        # NOTE(review): pickle.loads can execute arbitrary code; this is only
        # safe while every writer to this Redis instance is trusted.
        return pickle.loads(raw)
    finally:
        self.close()

def delete(self, key):
    """Delete ``key`` and return whatever the connection's ``delete`` returns.

    ``cluster_nodes`` is called with the key first; ``close`` is always
    called afterwards, even if the delete raises.
    """
    try:
        self.cluster_nodes(key)
        removed = self.connection.delete(key)
    finally:
        self.close()
    return removed

def exists(self, key):
    """Return the connection's ``exists`` result for ``key``.

    ``cluster_nodes`` is called with the key first; ``close`` is always
    called afterwards, even if the lookup raises.
    """
    try:
        self.cluster_nodes(key)
        found = self.connection.exists(key)
    finally:
        self.close()
    return found
Expand All @@ -102,19 +119,29 @@ def close(self):

def __contains__(self, key):
    """Check if the key is in the cache.

    A ``None`` key is never in the cache and is rejected before any
    cluster routing or network call, so ``None in client`` cannot fail
    while resolving a node for a missing key.

    Returns:
        bool: ``True`` when the connection reports the key exists.
    """
    if key is None:
        return False
    self.cluster_nodes(key)
    # exists() returns a count; normalise to a real boolean.
    return bool(self.connection.exists(key))

def __getitem__(self, key):
"""Retrieve an item from the cache using the square bracket notation."""
return self.get(key)
self.cluster_nodes(key)
return self.connection.get(key)

def __setitem__(self, key, value):
"""Add an item to the cache using the square bracket notation."""
self.set(key, value)
self.cluster_nodes(key)
self.connection.set(key, value)

def __delitem__(self, key):
"""Remove an item from the cache using the square bracket notation."""
self.delete(key)
self.cluster_nodes(key)
self.connection.delete(key)

def cluster_nodes(self, key):
    """Give a cluster connection a default node if it has none.

    Does nothing unless ``self.connection`` is a ``RedisCluster`` whose
    ``get_default_node()`` returns ``None``. In that case the node returned
    by ``get_node_from_key(key)`` is set as the default node.
    """
    conn = self.connection
    if not isinstance(conn, RedisCluster):
        return
    if conn.get_default_node() is not None:
        return
    conn.set_default_node(conn.get_node_from_key(key))


# 示例用法
Expand Down
4 changes: 3 additions & 1 deletion src/backend/bisheng/chat/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,15 @@ async def process_file(self, session: ChatManager, client_id: str, chat_id: str,
key = get_cache_key(client_id, chat_id)
langchain_object = session.in_memory_cache.get(key)
input_key = langchain_object.input_keys[0]
input_dict = {k: '' for k in langchain_object.input_keys}

report = ''
logger.info(f'process_file batch_question={batch_question}')
for question in batch_question:
if not question:
continue
payload = {'inputs': {input_key: question}, 'is_begin': False}
input_dict[input_key] = question
payload = {'inputs': input_dict, 'is_begin': False}
start_resp.category == 'question'
await session.send_json(client_id, chat_id, start_resp)
step_resp = ChatResponse(type='end',
Expand Down
2 changes: 2 additions & 0 deletions src/backend/bisheng/default_node.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ prompts:
documentation: "https://python.langchain.com/docs/modules/model_io/models/chat/how_to/prompts"
PromptTemplate:
documentation: "https://python.langchain.com/docs/modules/model_io/prompts/prompt_templates/"
MessagesPlaceholder:
documentation: ""
textsplitters:
CharacterTextSplitter:
documentation: "https://python.langchain.com/docs/modules/data_connection/document_transformers/text_splitters/character_text_splitter"
Expand Down
87 changes: 15 additions & 72 deletions src/backend/bisheng/interface/initialize/loading.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import contextlib
import json
from typing import Any, Callable, Dict, List, Sequence, Type
from typing import Any, Callable, Dict, Sequence, Type

from bisheng.cache.utils import file_download
from bisheng.chat.config import ChatConfig
Expand All @@ -9,6 +8,8 @@
from bisheng.interface.custom_lists import CUSTOM_NODES
from bisheng.interface.importing.utils import get_function, import_by_type
from bisheng.interface.initialize.llm import initialize_vertexai
from bisheng.interface.initialize.utils import (handle_format_kwargs, handle_node_type,
handle_partial_variables)
from bisheng.interface.initialize.vector_store import vecstore_initializer
from bisheng.interface.output_parsers.base import output_parser_creator
from bisheng.interface.retrievers.base import retriever_creator
Expand All @@ -19,15 +20,13 @@
from bisheng.utils import validate
from bisheng.utils.constants import NODE_ID_DICT, PRESET_QUESTION
from bisheng_langchain.vectorstores import VectorStoreFilterRetriever
from langchain.agents import ZeroShotAgent
from langchain.agents import agent as agent_module
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.agents.tools import BaseTool
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.document_loaders.base import BaseLoader
from langchain.schema import BaseOutputParser, Document
from langchain.vectorstores.base import VectorStore
from loguru import logger
from pydantic import ValidationError, create_model
Expand Down Expand Up @@ -344,78 +343,22 @@ def instantiate_agent(node_type, class_object: Type[agent_module.Agent], params:


def instantiate_prompt(node_type, class_object, params: Dict, param_id_dict: Dict):

if node_type == 'ZeroShotPrompt':
if 'tools' not in params:
params['tools'] = []
return ZeroShotAgent.create_prompt(**params)
elif 'MessagePromptTemplate' in node_type:
# Then we only need the template
from_template_params = {'template': params.pop('prompt', params.pop('template', ''))}

if not from_template_params.get('template'):
raise ValueError('Prompt template is required')
prompt = class_object.from_template(**from_template_params)

elif node_type == 'ChatPromptTemplate':
prompt = class_object.from_messages(**params)
else:
prompt = class_object(**params)
params, prompt = handle_node_type(node_type, class_object, params)
format_kwargs = handle_format_kwargs(prompt, params)
# Now we'll use partial_format to format the prompt
if format_kwargs:
prompt = handle_partial_variables(prompt, format_kwargs)

no_human_input = set(param_id_dict.keys())
human_input = set(prompt.input_variables).difference(no_human_input)
order_input = list(human_input) + list(set(prompt.input_variables) & no_human_input)
prompt.input_variables = order_input
format_kwargs: Dict[str, Any] = {}
for input_variable in prompt.input_variables:
if input_variable in params:
variable = params[input_variable]
if isinstance(variable, str):
format_kwargs[input_variable] = variable
elif isinstance(variable, BaseOutputParser) and hasattr(variable,
'get_format_instructions'):
format_kwargs[input_variable] = variable.get_format_instructions()
elif isinstance(variable, dict):
# variable node
if len(variable) == 0:
format_kwargs[input_variable] = ''
continue
elif len(variable) != 1:
raise ValueError(f'VariableNode contains multi-key {variable.keys()}')
format_kwargs[input_variable] = list(variable.values())[0]
elif isinstance(variable, List) and all(
isinstance(item, Document) for item in variable):
# Format document to contain page_content and metadata
# as one string separated by a newline
if len(variable) > 1:
content = '\n'.join(
[item.page_content for item in variable if item.page_content])
else:
if not variable:
format_kwargs[input_variable] = ''
continue
content = variable[0].page_content
# content could be a json list of strings
with contextlib.suppress(json.JSONDecodeError):
content = json.loads(content)
if isinstance(content, list):
content = ','.join([str(item) for item in content])
format_kwargs[input_variable] = content
# handle_keys will be a list but it does not exist yet
# so we need to create it

if (isinstance(variable, List) and all(
isinstance(item, Document)
for item in variable)) or (isinstance(variable, BaseOutputParser)
and hasattr(variable, 'get_format_instructions')):
if 'handle_keys' not in format_kwargs:
format_kwargs['handle_keys'] = []

# Add the handle_keys to the list
format_kwargs['handle_keys'].append(input_variable)

# from langchain.chains.router.llm_router import RouterOutputParser
# prompt.output_parser = RouterOutputParser()
if len(order_input) > 1:
# if node_type == 'ChatPromptTemplate':

if hasattr(prompt, 'prompt') and hasattr(prompt.prompt, 'input_variables'):
prompt.prompt.input_variables = order_input
elif hasattr(prompt, 'input_variables'):
prompt.input_variables = order_input
return prompt, format_kwargs


Expand Down
Loading

0 comments on commit 2fe6336

Please sign in to comment.