Merge pull request #2749 from zRzRzRzRzRzRzR/dev
Configuration file changes
zRzRzRzRzRzRzR committed Jan 22, 2024
2 parents c0968fb + 80c26e4 commit fb6c84b
Showing 5 changed files with 41 additions and 28 deletions.
32 changes: 19 additions & 13 deletions configs/model_config.py.example
@@ -8,7 +8,7 @@ MODEL_ROOT_PATH = ""
# Name of the selected Embedding model
EMBEDDING_MODEL = "bge-large-zh-v1.5"

-# Device for the Embedding model. Set to"auto" to auto-detect (a warning will appear), or set it manually to one of "cuda", "mps", "cpu", "xpu".
+# Device for the Embedding model. Set to "auto" to auto-detect (a warning will appear), or set it manually to one of "cuda", "mps", "cpu", "xpu".
EMBEDDING_DEVICE = "auto"

# Selected reranker model
@@ -26,11 +26,11 @@ EMBEDDING_MODEL_OUTPUT_PATH = "output"
# Here we use two currently mainstream offline models; chatglm3-6b is the model loaded by default.
# If VRAM is tight, you can use Qwen-1_8B-Chat, which needs only 3.8 GB of VRAM in FP16.

LLM_MODELS = ["zhipu-api"]
LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"]
Agent_MODEL = None

# Device for the LLM. Set to "auto" to auto-detect (a warning will appear), or set it manually to one of "cuda", "mps", "cpu", "xpu".
LLM_DEVICE = "cuda"
LLM_DEVICE = "auto"

HISTORY_LEN = 3
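
Both EMBEDDING_DEVICE and LLM_DEVICE above accept "auto". A minimal sketch of what such auto-detection could look like, assuming a hypothetical `detect_device` helper (the repo's real detection code may differ):

```python
import torch

def detect_device() -> str:
    # Hypothetical helper: prefer CUDA, then Apple MPS, fall back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

LLM_DEVICE = "auto"
device = detect_device() if LLM_DEVICE == "auto" else LLM_DEVICE
```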

@@ -66,7 +66,7 @@ ONLINE_LLM_MODEL = {
"APPID": "",
"APISecret": "",
"api_key": "",
"version": "v1.5", # 你使用的讯飞星火大模型版本,可选包括 "v3.0", "v1.5", "v2.0"
"version": "v3.0", # 你使用的讯飞星火大模型版本,可选包括 "v3.0", "v2.0", "v1.5"
"provider": "XingHuoWorker",
},

@@ -120,7 +120,7 @@ ONLINE_LLM_MODEL = {
"secret_key": "",
"provider": "TianGongWorker",
},
# Gemini API (开发组未测试,由社群提供,只支持pro)https://makersuite.google.com/或者google cloud,使用前先确认网络正常,使用代理请在项目启动(python startup.py -a)环境内设置https_proxy环境变量
# Gemini API (开发组未测试,由社群提供,只支持pro)
"gemini-api": {
"api_key": "",
"provider": "GeminiWorker",
@@ -155,7 +155,7 @@ MODEL_PATH = {
"bge-large-zh": "BAAI/bge-large-zh",
"bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
"bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5",
"bge-large-zh-v1.5": "/share/home/zyx/Models/bge-large-zh-v1.5",
"bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5",
"piccolo-base-zh": "sensenova/piccolo-base-zh",
"piccolo-large-zh": "sensenova/piccolo-large-zh",
"nlp_gte_sentence-embedding_chinese-large": "damo/nlp_gte_sentence-embedding_chinese-large",
@@ -168,11 +168,15 @@ MODEL_PATH = {
"chatglm3-6b": "THUDM/chatglm3-6b",
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",

"Orion-14B-Chat": "OrionStarAI/Orion-14B-Chat",
"Orion-14B-Chat-Plugin": "OrionStarAI/Orion-14B-Chat-Plugin",
"Orion-14B-LongChat": "OrionStarAI/Orion-14B-LongChat",

"Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
"Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
"Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",

"Qwen-1_8B-Chat": "/media/checkpoint/Qwen-1_8B-Chat",
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
Expand All @@ -197,7 +201,7 @@ MODEL_PATH = {
"agentlm-70b": "THUDM/agentlm-70b",

"falcon-7b": "tiiuae/falcon-7b",
"falcon-40b": "tiiuae/falcon-40b",
"falcon-40b": "tiiuae/falcon-40,b",
"falcon-rw-7b": "tiiuae/falcon-rw-7b",

"aquila-7b": "BAAI/Aquila-7B",
@@ -287,9 +291,11 @@ VLLM_MODEL_DICT = {
}

SUPPORT_AGENT_MODEL = [
"azure-api",
"openai-api",
"qwen-api",
"Qwen",
"chatglm3",
"openai-api", # GPT4 模型
"qwen-api", # Qwen Max模型
"zhipu-api", # 智谱AI GLM4模型
"Qwen", # 所有Qwen系列本地模型
"chatglm3-6b",
"internlm2-chat-20b",
"Orion-14B-Chat-Plugin",
]
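
The list now mixes exact worker names ("zhipu-api") with family prefixes ("Qwen", per the "all local Qwen-series models" comment), which suggests membership is tested by substring rather than equality. A sketch under that assumption; `supports_agent` is a made-up helper name:

```python
def supports_agent(model_name: str) -> bool:
    # Assumed semantics: any entry matching as a substring marks the model
    # as agent-capable, so "Qwen" covers Qwen-7B-Chat, Qwen-14B-Chat, etc.
    return any(entry in model_name for entry in SUPPORT_AGENT_MODEL)

supports_agent("Qwen-14B-Chat")  # True under this assumption
```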
3 changes: 3 additions & 0 deletions requirements.txt
@@ -40,6 +40,9 @@ einops>=0.7.0
transformers_stream_generator==0.0.4
vllm==0.2.7; sys_platform == "linux"

+# flash-attn>=2.4.3 # For Orion-14B-Chat and Qwen-14B-Chat
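
Since flash-attn ships commented out, code that benefits from it has to treat it as optional. A minimal guarded-import sketch (the HAS_FLASH_ATTN flag is made up for illustration):

```python
try:
    import flash_attn  # optional speed-up for Orion-14B-Chat / Qwen-14B-Chat
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False
```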


# optional document loaders

#rapidocr_paddle[gpu]>=1.3.0.post5 # gpu acceleration for ocr of pdf and image files
1 change: 1 addition & 0 deletions requirements_api.txt
@@ -39,6 +39,7 @@ transformers_stream_generator==0.0.4
vllm==0.2.7; sys_platform == "linux"
httpx==0.26.0
llama-index
+# flash-attn>=2.4.3 # For Orion-14B-Chat and Qwen-14B-Chat

# optional document loaders

4 changes: 2 additions & 2 deletions server/chat/agent_chat.py
@@ -19,6 +19,7 @@
from server.agent import model_container
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate


async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                     history: List[History] = Body([],
                                                   description="历史对话",
@@ -82,8 +83,7 @@ async def agent_chat_iterator(
                memory.chat_memory.add_user_message(message.content)
            else:
                memory.chat_memory.add_ai_message(message.content)

if "chatglm3" in model_container.MODEL.model_name:
if "chatglm3" in model_container.MODEL.model_name or "zhipu-api" in model_container.MODEL.model_name:
            agent_executor = initialize_glm3_agent(
                llm=model,
                tools=tools,
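
The widened condition sends both chatglm3 and the zhipu-api worker through the GLM3-style agent. If more workers adopt this prompt format, the substring checks could collapse into one helper; a refactor sketch, not code from this PR:

```python
GLM3_STYLE_MODELS = ("chatglm3", "zhipu-api")  # hypothetical constant

def uses_glm3_agent(model_name: str) -> bool:
    # Same substring test as the if-branch above, generalized to a tuple.
    return any(key in model_name for key in GLM3_STYLE_MODELS)
```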
29 changes: 16 additions & 13 deletions server/model_workers/zhipu.py
@@ -57,21 +57,24 @@ def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
"messages": params.messages,
"max_tokens": params.max_tokens,
"temperature": params.temperature,
"stream": True
"stream": False
}
url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
response = requests.post(url, headers=headers, json=data, stream=True)
for chunk in response.iter_lines():
if chunk:
chunk_str = chunk.decode('utf-8')
json_start_pos = chunk_str.find('{"id"')
if json_start_pos != -1:
json_str = chunk_str[json_start_pos:]
json_data = json.loads(json_str)
for choice in json_data.get('choices', []):
delta = choice.get('delta', {})
content = delta.get('content', '')
yield {"error_code": 0, "text": content}
response = requests.post(url, headers=headers, json=data)
# for chunk in response.iter_lines():
# if chunk:
# chunk_str = chunk.decode('utf-8')
# json_start_pos = chunk_str.find('{"id"')
# if json_start_pos != -1:
# json_str = chunk_str[json_start_pos:]
# json_data = json.loads(json_str)
# for choice in json_data.get('choices', []):
# delta = choice.get('delta', {})
# content = delta.get('content', '')
# yield {"error_code": 0, "text": content}
ans = response.json()
content = ans["choices"][0]["message"]["content"]
yield {"error_code": 0, "text": content}

    def get_embeddings(self, params):
        # Temporary workaround; embedding is not supported
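
With this change, do_chat blocks until the full completion arrives and yields it once; the incremental parser survives only as comments. If streaming is wanted back, the v4 endpoint appears to follow the OpenAI-style SSE format (lines prefixed with "data:", terminated by "data: [DONE]"). A hedged sketch under that assumption; `stream_chat` is not code from this PR:

```python
import json
import requests

def stream_chat(url: str, headers: dict, data: dict):
    # Sketch: request server-sent events and yield deltas as they arrive.
    payload = {**data, "stream": True}
    response = requests.post(url, headers=headers, json=payload, stream=True)
    for line in response.iter_lines():
        if not line:
            continue
        text = line.decode("utf-8")
        if not text.startswith("data:"):
            continue
        body = text[len("data:"):].strip()
        if body == "[DONE]":
            break
        for choice in json.loads(body).get("choices", []):
            content = choice.get("delta", {}).get("content", "")
            if content:
                yield {"error_code": 0, "text": content}
```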