Configuration file modification #2749

Merged 4 commits on Jan 22, 2024
configs/model_config.py.example (32 changes: 19 additions & 13 deletions)
@@ -8,7 +8,7 @@ MODEL_ROOT_PATH = ""
# Name of the selected Embedding model
EMBEDDING_MODEL = "bge-large-zh-v1.5"

# Device on which the Embedding model runs. Set to "auto" to detect automatically (a warning may appear), or set manually to one of "cuda", "mps", "cpu", or "xpu".
EMBEDDING_DEVICE = "auto"

# Selected reranker model
@@ -26,11 +26,11 @@ EMBEDDING_MODEL_OUTPUT_PATH = "output"
# Here we use two currently mainstream offline models; chatglm3-6b is the model loaded by default.
# If you are short on GPU memory, use Qwen-1_8B-Chat, which needs only 3.8 GB of VRAM at FP16.

LLM_MODELS = ["zhipu-api"]
LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"]
Agent_MODEL = None

# Device on which the LLM runs. Set to "auto" to detect automatically (a warning may appear), or set manually to one of "cuda", "mps", "cpu", or "xpu".
-LLM_DEVICE = "cuda"
+LLM_DEVICE = "auto"

HISTORY_LEN = 3

@@ -66,7 +66,7 @@ ONLINE_LLM_MODEL = {
"APPID": "",
"APISecret": "",
"api_key": "",
"version": "v1.5", # 你使用的讯飞星火大模型版本,可选包括 "v3.0", "v1.5", "v2.0"
"version": "v3.0", # 你使用的讯飞星火大模型版本,可选包括 "v3.0", "v2.0", "v1.5"
"provider": "XingHuoWorker",
},

@@ -120,7 +120,7 @@ ONLINE_LLM_MODEL = {
"secret_key": "",
"provider": "TianGongWorker",
},
-# Gemini API (not tested by the dev team; community-contributed; only supports pro) https://makersuite.google.com/ or Google Cloud. Confirm your network works before use; if you use a proxy, set the https_proxy environment variable in the environment the project is started in (python startup.py -a)
+# Gemini API (not tested by the dev team; community-contributed; only supports pro)
"gemini-api": {
"api_key": "",
"provider": "GeminiWorker",
@@ -155,7 +155,7 @@ MODEL_PATH = {
"bge-large-zh": "BAAI/bge-large-zh",
"bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
"bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5",
"bge-large-zh-v1.5": "/share/home/zyx/Models/bge-large-zh-v1.5",
"bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5",
"piccolo-base-zh": "sensenova/piccolo-base-zh",
"piccolo-large-zh": "sensenova/piccolo-large-zh",
"nlp_gte_sentence-embedding_chinese-large": "damo/nlp_gte_sentence-embedding_chinese-large",
@@ -168,11 +168,15 @@ MODEL_PATH = {
"chatglm3-6b": "THUDM/chatglm3-6b",
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",

"Orion-14B-Chat": "OrionStarAI/Orion-14B-Chat",
"Orion-14B-Chat-Plugin": "OrionStarAI/Orion-14B-Chat-Plugin",
"Orion-14B-LongChat": "OrionStarAI/Orion-14B-LongChat",

"Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
"Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
"Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",

"Qwen-1_8B-Chat": "/media/checkpoint/Qwen-1_8B-Chat",
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
@@ -197,7 +201,7 @@ MODEL_PATH = {
"agentlm-70b": "THUDM/agentlm-70b",

"falcon-7b": "tiiuae/falcon-7b",
"falcon-40b": "tiiuae/falcon-40b",
"falcon-40b": "tiiuae/falcon-40,b",
"falcon-rw-7b": "tiiuae/falcon-rw-7b",

"aquila-7b": "BAAI/Aquila-7B",
@@ -287,9 +291,11 @@ VLLM_MODEL_DICT = {
}

SUPPORT_AGENT_MODEL = [
"azure-api",
"openai-api",
"qwen-api",
"Qwen",
"chatglm3",
"openai-api", # GPT4 模型
"qwen-api", # Qwen Max模型
"zhipu-api", # 智谱AI GLM4模型
"Qwen", # 所有Qwen系列本地模型
"chatglm3-6b",
"internlm2-chat-20b",
"Orion-14B-Chat-Plugin",
]
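
Both EMBEDDING_DEVICE and LLM_DEVICE now default to "auto". A minimal sketch of how such auto-detection is typically implemented with PyTorch (detect_device and resolve_device are illustrative names, not functions from this repo):

import torch

def detect_device() -> str:
    # Prefer CUDA, then Apple MPS, falling back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"

def resolve_device(device: str = "auto") -> str:
    # Explicit settings ("cuda", "mps", "cpu", "xpu") pass through unchanged.
    return detect_device() if device == "auto" else device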
requirements.txt (3 changes: 3 additions & 0 deletions)
@@ -40,6 +40,9 @@ einops>=0.7.0
transformers_stream_generator==0.0.4
vllm==0.2.7; sys_platform == "linux"

+# flash-attn>=2.4.3 # For Orion-14B-Chat and Qwen-14B-Chat
+
+
# optional document loaders

#rapidocr_paddle[gpu]>=1.3.0.post5 # gpu acceleration for ocr of pdf and image files
requirements_api.txt (1 change: 1 addition & 0 deletions)
@@ -39,6 +39,7 @@ transformers_stream_generator==0.0.4
vllm==0.2.7; sys_platform == "linux"
httpx==0.26.0
llama-index
+# flash-attn>=2.4.3 # For Orion-14B-Chat and Qwen-14B-Chat

# optional document loaders

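flash-attn is left commented out in both requirements files because its wheels do not build on every platform; it only benefits models such as Orion-14B-Chat and Qwen-14B-Chat. A hedged sketch of probing for the optional package at runtime (the constant name below is illustrative, not the project's):

import importlib.util

# "flash_attn" is the import name of the flash-attn PyPI package.
HAS_FLASH_ATTN = importlib.util.find_spec("flash_attn") is not None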
server/chat/agent_chat.py (4 changes: 2 additions & 2 deletions)
@@ -19,6 +19,7 @@
from server.agent import model_container
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate


async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
history: List[History] = Body([],
description="历史对话",
@@ -82,8 +83,7 @@ async def agent_chat_iterator(
memory.chat_memory.add_user_message(message.content)
else:
memory.chat_memory.add_ai_message(message.content)

if "chatglm3" in model_container.MODEL.model_name:
if "chatglm3" in model_container.MODEL.model_name or "zhipu-api" in model_container.MODEL.model_name:
agent_executor = initialize_glm3_agent(
llm=model,
tools=tools,
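With this change, the zhipu-api online worker is routed through the same GLM3-style agent initializer as local chatglm3 models. A hedged usage sketch of exercising the endpoint (the route and default port are assumptions based on the project's usual API layout, not confirmed by this diff):

import requests

# Assumed default API address and route; adjust to your deployment.
resp = requests.post(
    "http://127.0.0.1:7861/chat/agent_chat",
    json={"query": "What is 35 * 12?", "history": [], "stream": False},
)
print(resp.text)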
server/model_workers/zhipu.py (29 changes: 16 additions & 13 deletions)
@@ -57,21 +57,24 @@ def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
"messages": params.messages,
"max_tokens": params.max_tokens,
"temperature": params.temperature,
"stream": True
"stream": False
}
url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
-response = requests.post(url, headers=headers, json=data, stream=True)
-for chunk in response.iter_lines():
-    if chunk:
-        chunk_str = chunk.decode('utf-8')
-        json_start_pos = chunk_str.find('{"id"')
-        if json_start_pos != -1:
-            json_str = chunk_str[json_start_pos:]
-            json_data = json.loads(json_str)
-            for choice in json_data.get('choices', []):
-                delta = choice.get('delta', {})
-                content = delta.get('content', '')
-                yield {"error_code": 0, "text": content}
+response = requests.post(url, headers=headers, json=data)
+# for chunk in response.iter_lines():
+#     if chunk:
+#         chunk_str = chunk.decode('utf-8')
+#         json_start_pos = chunk_str.find('{"id"')
+#         if json_start_pos != -1:
+#             json_str = chunk_str[json_start_pos:]
+#             json_data = json.loads(json_str)
+#             for choice in json_data.get('choices', []):
+#                 delta = choice.get('delta', {})
+#                 content = delta.get('content', '')
+#                 yield {"error_code": 0, "text": content}
+ans = response.json()
+content = ans["choices"][0]["message"]["content"]
+yield {"error_code": 0, "text": content}

def get_embeddings(self, params):
# Temporary workaround; embeddings are not supported
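The worker now makes one blocking request and yields the whole completion, instead of parsing the SSE stream line by line. A self-contained sketch of the new request shape (token construction, error handling, and the model name "glm-4" are simplified assumptions, not taken from this diff):

import requests

def zhipu_chat(token: str, messages: list, model: str = "glm-4") -> str:
    # Non-streaming call against the v4 chat-completions endpoint used above.
    url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
    headers = {"Authorization": f"Bearer {token}"}
    data = {"model": model, "messages": messages, "stream": False}
    resp = requests.post(url, headers=headers, json=data, timeout=60)
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]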