diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 25fcf12b3..be1f4d185 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -8,7 +8,7 @@ MODEL_ROOT_PATH = ""
 # 选用的 Embedding 名称
 EMBEDDING_MODEL = "bge-large-zh-v1.5"
 
-# Embedding 模型运行设备。设为"auto"会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
+# Embedding 模型运行设备。设为 "auto" 会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
 EMBEDDING_DEVICE = "auto"
 
 # 选用的reranker模型
@@ -26,11 +26,11 @@ EMBEDDING_MODEL_OUTPUT_PATH = "output"
 # 在这里,我们使用目前主流的两个离线模型,其中,chatglm3-6b 为默认加载模型。
 # 如果你的显存不足,可使用 Qwen-1_8B-Chat, 该模型 FP16 仅需 3.8G显存。
-LLM_MODELS = ["zhipu-api"]
+LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"]
 Agent_MODEL = None
 
 # LLM 模型运行设备。设为"auto"会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
-LLM_DEVICE = "cuda"
+LLM_DEVICE = "auto"
 
 HISTORY_LEN = 3
 
@@ -66,7 +66,7 @@ ONLINE_LLM_MODEL = {
         "APPID": "",
         "APISecret": "",
         "api_key": "",
-        "version": "v1.5",  # 你使用的讯飞星火大模型版本,可选包括 "v3.0", "v1.5", "v2.0"
+        "version": "v3.0",  # 你使用的讯飞星火大模型版本,可选包括 "v3.0", "v2.0", "v1.5"
         "provider": "XingHuoWorker",
     },
 
@@ -120,7 +120,7 @@ ONLINE_LLM_MODEL = {
         "secret_key": "",
         "provider": "TianGongWorker",
     },
-    # Gemini API (开发组未测试,由社群提供,只支持pro)https://makersuite.google.com/或者google cloud,使用前先确认网络正常,使用代理请在项目启动(python startup.py -a)环境内设置https_proxy环境变量
+    # Gemini API (开发组未测试,由社群提供,只支持pro)
     "gemini-api": {
         "api_key": "",
         "provider": "GeminiWorker",
@@ -155,7 +155,7 @@ MODEL_PATH = {
         "bge-large-zh": "BAAI/bge-large-zh",
         "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
         "bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5",
-        "bge-large-zh-v1.5": "/share/home/zyx/Models/bge-large-zh-v1.5",
+        "bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5",
         "piccolo-base-zh": "sensenova/piccolo-base-zh",
         "piccolo-large-zh": "sensenova/piccolo-large-zh",
         "nlp_gte_sentence-embedding_chinese-large": "damo/nlp_gte_sentence-embedding_chinese-large",
@@ -168,11 +168,15 @@ MODEL_PATH = {
         "chatglm3-6b": "THUDM/chatglm3-6b",
         "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
 
+        "Orion-14B-Chat": "OrionStarAI/Orion-14B-Chat",
+        "Orion-14B-Chat-Plugin": "OrionStarAI/Orion-14B-Chat-Plugin",
+        "Orion-14B-LongChat": "OrionStarAI/Orion-14B-LongChat",
+
         "Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
         "Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
         "Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
 
-        "Qwen-1_8B-Chat": "/media/checkpoint/Qwen-1_8B-Chat",
+        "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
         "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
         "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
         "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
@@ -197,7 +201,7 @@ MODEL_PATH = {
         "agentlm-70b": "THUDM/agentlm-70b",
 
         "falcon-7b": "tiiuae/falcon-7b",
         "falcon-40b": "tiiuae/falcon-40b",
         "falcon-rw-7b": "tiiuae/falcon-rw-7b",
 
         "aquila-7b": "BAAI/Aquila-7B",
@@ -287,9 +291,11 @@ VLLM_MODEL_DICT = {
 }
 
 SUPPORT_AGENT_MODEL = [
-    "azure-api",
-    "openai-api",
-    "qwen-api",
-    "Qwen",
-    "chatglm3",
+    "openai-api",  # GPT4 模型
+    "qwen-api",  # Qwen Max模型
+    "zhipu-api",  # 智谱AI GLM4模型
+    "Qwen",  # 所有Qwen系列本地模型
+    "chatglm3-6b",
+    "internlm2-chat-20b",
+    "Orion-14B-Chat-Plugin",
 ]
diff --git a/requirements.txt b/requirements.txt
index db1cbe282..3484ba4e4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -40,6 +40,9 @@ einops>=0.7.0
 transformers_stream_generator==0.0.4
 vllm==0.2.7; sys_platform == "linux"
 
+# flash-attn>=2.4.3  # For Orion-14B-Chat and Qwen-14B-Chat
+
+
 # optional document loaders
 
 #rapidocr_paddle[gpu]>=1.3.0.post5  # gpu acceleration for ocr of pdf and image files
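Both `EMBEDDING_DEVICE` and `LLM_DEVICE` above now default to `"auto"`. As a rough illustration of what such auto-detection typically involves, here is a minimal sketch assuming `torch` is installed; the helper names are hypothetical, and the project's actual detection code is not part of this diff:

```python
import torch

def detect_device() -> str:
    # Hypothetical sketch of "auto" resolution; the project's real helper
    # lives outside this diff and may differ (e.g. it may also probe "xpu").
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

def resolve_device(device: str) -> str:
    # "auto" triggers detection; explicit values ("cuda", "mps", "cpu", "xpu")
    # pass through unchanged.
    return detect_device() if device == "auto" else device
```

Setting an explicit device skips detection entirely, which also avoids the warning the config comment mentions.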
diff --git a/requirements_api.txt b/requirements_api.txt
index 60b884ebe..e126c8c58 100644
--- a/requirements_api.txt
+++ b/requirements_api.txt
@@ -39,6 +39,7 @@ transformers_stream_generator==0.0.4
 vllm==0.2.7; sys_platform == "linux"
 httpx==0.26.0
 llama-index
+# flash-attn>=2.4.3  # For Orion-14B-Chat and Qwen-14B-Chat
 
 # optional document loaders
diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py
index f47958a05..41bf5baba 100644
--- a/server/chat/agent_chat.py
+++ b/server/chat/agent_chat.py
@@ -19,6 +19,7 @@ from server.agent import model_container
 from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
 
 
+
 async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                      history: List[History] = Body([],
                                                    description="历史对话",
@@ -82,8 +83,7 @@ async def agent_chat_iterator(
                 memory.chat_memory.add_user_message(message.content)
             else:
                 memory.chat_memory.add_ai_message(message.content)
-
-        if "chatglm3" in model_container.MODEL.model_name:
+        if "chatglm3" in model_container.MODEL.model_name or "zhipu-api" in model_container.MODEL.model_name:
             agent_executor = initialize_glm3_agent(
                 llm=model,
                 tools=tools,
diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py
index 552b67cca..0ed3bdcdb 100644
--- a/server/model_workers/zhipu.py
+++ b/server/model_workers/zhipu.py
@@ -57,21 +57,24 @@ def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
             "messages": params.messages,
             "max_tokens": params.max_tokens,
             "temperature": params.temperature,
-            "stream": True
+            "stream": False
         }
 
         url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
-        response = requests.post(url, headers=headers, json=data, stream=True)
-        for chunk in response.iter_lines():
-            if chunk:
-                chunk_str = chunk.decode('utf-8')
-                json_start_pos = chunk_str.find('{"id"')
-                if json_start_pos != -1:
-                    json_str = chunk_str[json_start_pos:]
-                    json_data = json.loads(json_str)
-                    for choice in json_data.get('choices', []):
-                        delta = choice.get('delta', {})
-                        content = delta.get('content', '')
-                        yield {"error_code": 0, "text": content}
+        response = requests.post(url, headers=headers, json=data)
+        # for chunk in response.iter_lines():
+        #     if chunk:
+        #         chunk_str = chunk.decode('utf-8')
+        #         json_start_pos = chunk_str.find('{"id"')
+        #         if json_start_pos != -1:
+        #             json_str = chunk_str[json_start_pos:]
+        #             json_data = json.loads(json_str)
+        #             for choice in json_data.get('choices', []):
+        #                 delta = choice.get('delta', {})
+        #                 content = delta.get('content', '')
+        #                 yield {"error_code": 0, "text": content}
+        ans = response.json()
+        content = ans["choices"][0]["message"]["content"]
+        yield {"error_code": 0, "text": content}
 
     def get_embeddings(self, params):
         # 临时解决方案,不支持embedding
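A note on the zhipu.py change: `do_chat` keeps its `Iterator[Dict]` signature, but with `"stream": False` it now yields exactly one chunk containing the full completion, so downstream consumers that iterate over the generator keep working unchanged. A minimal consumer sketch; `worker` and `params` are hypothetical stand-ins for the actual worker instance and `ApiChatParams` the server constructs at runtime:

```python
# Hypothetical consumer; `worker` and `params` stand in for however the
# zhipu model worker and ApiChatParams are actually built by the server.
text = ""
for chunk in worker.do_chat(params):
    if chunk.get("error_code", 0) != 0:
        raise RuntimeError(f"zhipu-api error: {chunk}")
    text += chunk["text"]  # with "stream": False this loop runs exactly once
print(text)
```

One consequence worth flagging: the answer now arrives all at once rather than token by token, so front-end output for zhipu-api will appear in a single burst instead of streaming.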