Skip to content

Commit

Permalink
给 ApiModelWorker 添加 logger 成员变量,API请求出错时输出有意义的错误信息。 (#2169)
Browse files Browse the repository at this point in the history
* 给 ApiModelWorker 添加 logger 成员变量,API请求出错时输出有意义的错误信息。
/chat/chat 接口 conversation_id参数改为默认 "",避免 swagger 页面默认值错误导致历史消息失效

* 修复在线模型一些bug
  • Loading branch information
liunux4odoo committed Nov 25, 2023
1 parent 1b0cf67 commit 1de4258
Show file tree
Hide file tree
Showing 12 changed files with 85 additions and 58 deletions.
4 changes: 2 additions & 2 deletions server/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
conversation_id: str = Body(None, description="对话框ID"),
conversation_id: str = Body("", description="对话框ID"),
history: Union[int, List[History]] = Body([],
description="历史对话,设为一个整数可以从数据库中读取历史消息",
examples=[[
Expand Down Expand Up @@ -54,7 +54,7 @@ async def chat_iterator() -> AsyncIterable[str]:
callbacks=callbacks,
)

if conversation_id is None:
if not conversation_id:
history = [History.from_data(h) for h in history]
prompt_template = get_prompt_template("llm_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
Expand Down
6 changes: 4 additions & 2 deletions server/model_workers/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class AzureWorker(ApiModelWorker):
def __init__(
self,
*,
controller_addr: str,
worker_addr: str,
controller_addr: str = None,
worker_addr: str = None,
model_names: List[str] = ["azure-api"],
version: str = "gpt-35-turbo",
**kwargs,
Expand Down Expand Up @@ -60,6 +60,8 @@ def do_chat(self, params: ApiChatParams) -> Dict:
"error_code": 0,
"text": text
}
else:
self.logger.error(f"请求 Azure API 时发生错误:{resp}")

def get_embeddings(self, params):
# TODO: 支持embeddings
Expand Down
4 changes: 3 additions & 1 deletion server/model_workers/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def do_chat(self, params: ApiChatParams) -> Dict:
"text": text
}
else:
yield {
data = {
"error_code": resp["code"],
"text": resp["msg"],
"error": {
Expand All @@ -84,6 +84,8 @@ def do_chat(self, params: ApiChatParams) -> Dict:
"code": None,
}
}
self.logger.error(f"请求百川 API 时发生错误:{data}")
yield data

def get_embeddings(self, params):
# TODO: 支持embeddings
Expand Down
6 changes: 2 additions & 4 deletions server/model_workers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@

__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"]

# 恢复被fastchat覆盖的标准输出
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__


class ApiConfigParams(BaseModel):
'''
Expand Down Expand Up @@ -110,7 +106,9 @@ def __init__(
controller_addr=controller_addr,
worker_addr=worker_addr,
**kwargs)
import fastchat.serve.base_model_worker
import sys
self.logger = fastchat.serve.base_model_worker.logger
# 恢复被fastchat覆盖的标准输出
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
Expand Down
42 changes: 26 additions & 16 deletions server/model_workers/fangzhou.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,33 @@ def do_chat(self, params: ApiChatParams) -> Dict:

text = ""
if log_verbose:
logger.info(f'{self.__class__.__name__}:maas: {maas}')
self.logger.info(f'{self.__class__.__name__}:maas: {maas}')
for resp in maas.stream_chat(req):
error = resp.error
if error.code_n > 0:
yield {
"error_code": error.code_n,
"text": error.message,
"error": {
"message": error.message,
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
elif chunk := resp.choice.message.content:
text += chunk
yield {"error_code": 0, "text": text}
if error := resp.error:
if error.code_n > 0:
data = {
"error_code": error.code_n,
"text": error.message,
"error": {
"message": error.message,
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求方舟 API 时发生错误:{data}")
yield data
elif chunk := resp.choice.message.content:
text += chunk
yield {"error_code": 0, "text": text}
else:
data = {
"error_code": 500,
"text": f"请求方舟 API 时发生未知的错误: {resp}"
}
self.logger.error(data)
yield data
break

def get_embeddings(self, params):
# TODO: 支持embeddings
Expand Down
8 changes: 6 additions & 2 deletions server/model_workers/minimax.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def do_chat(self, params: ApiChatParams) -> Dict:
text = ""
for e in r.iter_text():
if not e.startswith("data: "): # 真是优秀的返回
yield {
data = {
"error_code": 500,
"text": f"minimax返回错误的结果:{e}",
"error": {
Expand All @@ -84,6 +84,8 @@ def do_chat(self, params: ApiChatParams) -> Dict:
"code": None,
}
}
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
yield data
continue

data = json.loads(e[6:])
Expand Down Expand Up @@ -123,7 +125,7 @@ def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
if embeddings := r.get("vectors"):
result += embeddings
elif error := r.get("base_resp"):
return {
data = {
"code": error["status_code"],
"msg": error["status_msg"],
"error": {
Expand All @@ -133,6 +135,8 @@ def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
"code": None,
}
}
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
return data
i += 10
return {"code": 200, "data": embeddings}

Expand Down
8 changes: 6 additions & 2 deletions server/model_workers/qianfan.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def do_chat(self, params: ApiChatParams) -> Dict:
"text": text
}
else:
yield {
data = {
"error_code": resp["error_code"],
"text": resp["error_msg"],
"error": {
Expand All @@ -164,6 +164,8 @@ def do_chat(self, params: ApiChatParams) -> Dict:
"code": None,
}
}
self.logger.error(f"请求千帆 API 时发生错误:{data}")
yield data

def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
params.load_config(self.model_names[0])
Expand All @@ -189,7 +191,7 @@ def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
for texts in params.texts[i:i+10]:
resp = client.post(url, json={"input": texts}).json()
if "error_cdoe" in resp:
return {
data = {
"code": resp["error_code"],
"msg": resp["error_msg"],
"error": {
Expand All @@ -199,6 +201,8 @@ def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
"code": None,
}
}
self.logger.error(f"请求千帆 API 时发生错误:{data}")
return data
else:
embeddings = [x["embedding"] for x in resp.get("data", [])]
result += embeddings
Expand Down
9 changes: 6 additions & 3 deletions server/model_workers/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def do_chat(self, params: ApiChatParams) -> Dict:
"text": choices[0]["message"]["content"],
}
else:
yield {
data = {
"error_code": resp["status_code"],
"text": resp["message"],
"error": {
Expand All @@ -63,7 +63,8 @@ def do_chat(self, params: ApiChatParams) -> Dict:
"code": None,
}
}

self.logger.error(f"请求千问 API 时发生错误:{data}")
yield data

def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import dashscope
Expand All @@ -80,7 +81,7 @@ def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
api_key=params.api_key,
)
if resp["status_code"] != 200:
return {
data = {
"code": resp["status_code"],
"msg": resp.message,
"error": {
Expand All @@ -90,6 +91,8 @@ def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
"code": None,
}
}
self.logger.error(f"请求千问 API 时发生错误:{data}")
return data
else:
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
result += embeddings
Expand Down
9 changes: 4 additions & 5 deletions server/model_workers/tiangong.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from server.model_workers.base import *
from server.utils import get_httpx_client
from fastchat import conversation as conv
import sys
import json
from typing import List, Literal, Dict
import requests
Expand Down Expand Up @@ -64,12 +63,12 @@ def do_chat(self, params: ApiChatParams) -> Dict:
"text": text
}
else:
yield {
data = {
"error_code": resp["code"],
"text": resp["code_msg"]
}


}
self.logger.error(f"请求天工 API 时出错:{data}")
yield data

def get_embeddings(self, params):
# TODO: 支持embeddings
Expand Down
2 changes: 1 addition & 1 deletion server/model_workers/xinghuo.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_version_details(version_key):
loop = asyncio.get_event_loop()
except:
loop = asyncio.new_event_loop()
params.max_tokens = min(details["max_tokens"], params.max_tokens)
params.max_tokens = min(details["max_tokens"], params.max_tokens or 0)
for chunk in iter_over_async(
request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages,
params.temperature, params.max_tokens),
Expand Down
9 changes: 7 additions & 2 deletions server/model_workers/zhipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
if e.event == "add":
yield {"error_code": 0, "text": e.data}
elif e.event in ["error", "interrupted"]:
yield {
data = {
"error_code": 500,
"text": str(e),
"error": {
Expand All @@ -54,6 +54,8 @@ def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
"code": None,
}
}
self.logger.error(f"请求智谱 API 时发生错误:{data}")
yield data

def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import zhipuai
Expand All @@ -68,9 +70,12 @@ def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
if response["code"] == 200:
embeddings.append(response["data"]["embedding"])
else:
self.logger.error(f"请求智谱 API 时发生错误:{response}")
return response # dict with code & msg
except Exception as e:
return {"code": 500, "msg": f"对文本向量化时出错:{e}"}
self.logger.error(f"请求智谱 API 时发生错误:{data}")
data = {"code": 500, "msg": f"对文本向量化时出错:{e}"}
return data

return {"code": 200, "data": embeddings}

Expand Down
36 changes: 18 additions & 18 deletions tests/test_online_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
workers.append(x)
print(f"all workers to test: {workers}")

# workers = ["qianfan-api"]
# workers = ["fangzhou-api"]


@pytest.mark.parametrize("worker", workers)
Expand All @@ -28,11 +28,11 @@ def test_chat(worker):
)
print(f"\nchat with {worker} \n")

worker_class = get_model_worker_config(worker)["worker_class"]
for x in worker_class().do_chat(params):
pprint(x)
assert isinstance(x, dict)
assert x["error_code"] == 0
if worker_class := get_model_worker_config(worker).get("worker_class"):
for x in worker_class().do_chat(params):
pprint(x)
assert isinstance(x, dict)
assert x["error_code"] == 0


@pytest.mark.parametrize("worker", workers)
Expand All @@ -44,19 +44,19 @@ def test_embeddings(worker):
]
)

worker_class = get_model_worker_config(worker)["worker_class"]
if worker_class.can_embedding():
print(f"\embeddings with {worker} \n")
resp = worker_class().do_embeddings(params)
if worker_class := get_model_worker_config(worker).get("worker_class"):
if worker_class.can_embedding():
print(f"\embeddings with {worker} \n")
resp = worker_class().do_embeddings(params)

pprint(resp, depth=2)
assert resp["code"] == 200
assert "data" in resp
embeddings = resp["data"]
assert isinstance(embeddings, list) and len(embeddings) > 0
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
assert isinstance(embeddings[0][0], float)
print("向量长度:", len(embeddings[0]))
pprint(resp, depth=2)
assert resp["code"] == 200
assert "data" in resp
embeddings = resp["data"]
assert isinstance(embeddings, list) and len(embeddings) > 0
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
assert isinstance(embeddings[0][0], float)
print("向量长度:", len(embeddings[0]))


# @pytest.mark.parametrize("worker", workers)
Expand Down

0 comments on commit 1de4258

Please sign in to comment.